diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index c468269f3ff..915d602b2e6 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -834,6 +834,7 @@ def __init__(
             device=device,
             enable_memory_saver=False,
         )
+        # Dict mapping gives O(1) layer-id lookups; see _transfer_full_attention_id below.
         self.full_attention_layer_id_mapping = {
             id: i for i, id in enumerate(full_attention_layer_ids)
         }
@@ -856,11 +857,14 @@ def maybe_get_custom_mem_pool(self):
         return self.full_kv_pool.maybe_get_custom_mem_pool()
 
     def _transfer_full_attention_id(self, layer_id: int):
-        if layer_id not in self.full_attention_layer_id_mapping:
+        try:
+            return self.full_attention_layer_id_mapping[layer_id]
+        except KeyError:
+            # Only reached on a miss; materialize the keys so the error message prints a plain list.
+            keys = list(self.full_attention_layer_id_mapping.keys())
             raise ValueError(
-                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+                f"{layer_id=} not in full attention layers: {keys}"
             )
-        return self.full_attention_layer_id_mapping[layer_id]
 
     def get_key_buffer(self, layer_id: int):
         layer_id = self._transfer_full_attention_id(layer_id)
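
The patch replaces the look-before-you-leap membership test with an EAFP try/except, so the common hit path performs a single dict lookup instead of two. Below is a minimal standalone sketch of the two patterns; the names `lookup_lbyl`, `lookup_eafp`, and the toy mapping are illustrative assumptions, not part of the patch.

```python
# Illustrative sketch (not from the patch): compare the double-lookup pattern
# ("in" check followed by indexing) against the try/except pattern the patch
# adopts, on the hot path where the key is present.
import timeit

# Toy stand-in for full_attention_layer_id_mapping: layer id -> pool index.
mapping = {layer_id: idx for idx, layer_id in enumerate(range(0, 64, 2))}

def lookup_lbyl(layer_id: int) -> int:
    # Look-before-you-leap: two hash lookups on every successful call.
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} not in full attention layers: {list(mapping)}")
    return mapping[layer_id]

def lookup_eafp(layer_id: int) -> int:
    # Easier-to-ask-forgiveness: one hash lookup when the key exists.
    try:
        return mapping[layer_id]
    except KeyError:
        raise ValueError(f"{layer_id=} not in full attention layers: {list(mapping)}")

for fn in (lookup_lbyl, lookup_eafp):
    t = timeit.timeit(lambda: fn(32), number=1_000_000)
    print(f"{fn.__name__}: {t:.3f}s per 1M hits")
```

On a hit the try block adds no lookup overhead; the KeyError branch is only taken when a caller passes a layer id that is not a full-attention layer, which is already the error case the original code raised on.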