10 changes: 7 additions & 3 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -834,6 +834,7 @@ def __init__(
device=device,
enable_memory_saver=False,
)
# Map each full-attention layer id to its pool index; lookups go through _transfer_full_attention_id below.
self.full_attention_layer_id_mapping = {
id: i for i, id in enumerate(full_attention_layer_ids)
}
@@ -856,11 +857,14 @@ def maybe_get_custom_mem_pool(self):
return self.full_kv_pool.maybe_get_custom_mem_pool()

def _transfer_full_attention_id(self, layer_id: int):
if layer_id not in self.full_attention_layer_id_mapping:
try:
return self.full_attention_layer_id_mapping[layer_id]
except KeyError:
# Materialize the keys only on the error path so the message formats as a plain list.
keys = list(self.full_attention_layer_id_mapping.keys())
raise ValueError(
f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
f"{layer_id=} not in full attention layers: {keys}"
)
return self.full_attention_layer_id_mapping[layer_id]

def get_key_buffer(self, layer_id: int):
layer_id = self._transfer_full_attention_id(layer_id)
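A minimal sketch (not part of the PR) of why the new lookup pattern helps: the old code probes the dict twice on every successful call (the `in` check plus the indexing), while the try/except form probes once and only pays for exception handling on the rare miss. The `mapping` and the benchmark harness below are illustrative assumptions, not SGLang internals.

# Illustrative micro-benchmark (assumed names, not SGLang code) comparing the two
# dict-lookup patterns shown in the diff above.
import timeit

# Stand-in for full_attention_layer_id_mapping: layer id -> pool index.
mapping = {layer_id: i for i, layer_id in enumerate(range(0, 64, 2))}

def lookup_check_then_get(layer_id: int) -> int:
    # Old pattern: membership test plus indexing -> two hash probes on the hot path.
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} not in full attention layers: {list(mapping.keys())}")
    return mapping[layer_id]

def lookup_try_except(layer_id: int) -> int:
    # New pattern: a single dict access; the KeyError branch only runs on a miss.
    try:
        return mapping[layer_id]
    except KeyError:
        keys = list(mapping.keys())
        raise ValueError(f"{layer_id=} not in full attention layers: {keys}")

if __name__ == "__main__":
    print("check-then-get:", timeit.timeit(lambda: lookup_check_then_get(32), number=1_000_000))
    print("try/except:    ", timeit.timeit(lambda: lookup_try_except(32), number=1_000_000))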