Skip to content

Commit 62c0b91

Browse files
committed
Update comments on _resolve_num_attention_layers
1 parent 49dc918 commit 62c0b91

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

src/hf_mem/safetensors/kv_cache.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,10 @@
55
from hf_mem.safetensors.types import TorchDtypes, get_safetensors_dtype_bytes, torch_dtype_to_safetensors_dtype
66

77

8+
# NOTE: Only full-attention (global) layers grow the KV cache with max_model_len; meaning
9+
# that only those contribute to a KV cache that scales with the context length, whereas the sliding
10+
# window layers reuse a fixed-size buffer and are excluded from the estimation.
811
def _resolve_num_attention_layers(config: Dict[str, Any]) -> int:
9-
"""Return the number of full-attention (global) layers that grow the KV cache with max_model_len.
10-
11-
Models with hybrid attention mix full attention and sliding window attention layers. Only the
12-
full attention layers contribute to a KV cache that scales with the context length; sliding
13-
window layers reuse a fixed-size buffer and are therefore excluded from the estimate.
14-
"""
1512
num_hidden_layers: int = config["num_hidden_layers"]
1613

1714
# NOTE: Gemma3-style hybrid attention: every N-th layer (0-indexed where `i % N == N-1`)
@@ -21,12 +18,11 @@ def _resolve_num_attention_layers(config: Dict[str, Any]) -> int:
2118
if "sliding_window_pattern" in config:
2219
return num_hidden_layers // config["sliding_window_pattern"]
2320

24-
# NOTE: Some models provide an explicit list of layer types — count the non-sliding-window ones.
25-
# Known string values for full attention vary by architecture.
21+
# NOTE: Some models provide an explicit list of layer types, so we need to count the non-sliding-window ones.
2622
if "layer_types" in config:
2723
return sum(1 for t in config["layer_types"] if t in {"attention", "full_attention", "global_attention"})
2824

29-
# NOTE: Default — assume all layers use full attention (standard MHA / GQA without SWA).
25+
# NOTE: By default assume all layers use full attention (standard MHA / GQA without SWA).
3026
return num_hidden_layers
3127

3228

@@ -36,7 +32,6 @@ def resolve_kv_cache_dtype(
3632
metadata: SafetensorsMetadata,
3733
model_id: str,
3834
) -> str:
39-
"""Resolve the effective KV cache dtype string (a `SafetensorsDtypes` value) from the CLI flag and config."""
4035
if kv_cache_dtype in {"fp8_e5m2", "fp8_e4m3"}:
4136
return kv_cache_dtype.upper().replace("FP8", "F8") # type: ignore[union-attr]
4237

@@ -116,18 +111,14 @@ def compute_safetensors_kv_cache_size(
116111
max_model_len: int,
117112
batch_size: int = 1,
118113
) -> int:
119-
"""Compute the KV cache memory requirement in bytes for a Safetensors model.
120-
121-
Reference: https://gist.github.com/alvarobartt/1097ca1b07c66fd71470937d599c2072
122-
"""
123114
hidden_size: int = config["hidden_size"]
124115
num_attention_heads: int = config["num_attention_heads"]
125116

126117
# NOTE: `num_key_value_heads` defaults to `num_attention_heads` in MHA, and is explicitly
127118
# set to a smaller value in GQA / MQA
128119
num_key_value_heads: int = config.get("num_key_value_heads", num_attention_heads)
129120

130-
# NOTE: Use head_dim directly if specified in the config; some models (e.g. Qwen3) set
121+
# NOTE: Use `head_dim` directly if specified in the config; some models (e.g. Qwen3) set
131122
# hidden_size and num_attention_heads independently from the actual per-head size,
132123
# making the fallback `hidden_size // num_attention_heads` incorrect for those models
133124
head_dim: int = config.get("head_dim", hidden_size // num_attention_heads)

0 commit comments

Comments
 (0)