55from hf_mem .safetensors .types import TorchDtypes , get_safetensors_dtype_bytes , torch_dtype_to_safetensors_dtype
66
77
8+ # NOTE: Only full-attention (global) layers grow the KV cache with max_model_len; meaning
9+ # that only those contribute to a KV cache that scales with the context length, whereas the sliding
10+ # window layers reuse a fixed-size buffer and are excluded from the estimation.
811def _resolve_num_attention_layers (config : Dict [str , Any ]) -> int :
9- """Return the number of full-attention (global) layers that grow the KV cache with max_model_len.
10-
11- Models with hybrid attention mix full attention and sliding window attention layers. Only the
12- full attention layers contribute to a KV cache that scales with the context length; sliding
13- window layers reuse a fixed-size buffer and are therefore excluded from the estimate.
14- """
1512 num_hidden_layers : int = config ["num_hidden_layers" ]
1613
1714 # NOTE: Gemma3-style hybrid attention: every N-th layer (0-indexed where `i % N == N-1`)
@@ -21,12 +18,11 @@ def _resolve_num_attention_layers(config: Dict[str, Any]) -> int:
2118 if "sliding_window_pattern" in config :
2219 return num_hidden_layers // config ["sliding_window_pattern" ]
2320
24- # NOTE: Some models provide an explicit list of layer types — count the non-sliding-window ones.
25- # Known string values for full attention vary by architecture.
21+ # NOTE: Some models provide an explicit list of layer types, so we need to count the non-sliding-window ones.
2622 if "layer_types" in config :
2723 return sum (1 for t in config ["layer_types" ] if t in {"attention" , "full_attention" , "global_attention" })
2824
29- # NOTE: Default — assume all layers use full attention (standard MHA / GQA without SWA).
25+ # NOTE: By default assume all layers use full attention (standard MHA / GQA without SWA).
3026 return num_hidden_layers
3127
3228
@@ -36,7 +32,6 @@ def resolve_kv_cache_dtype(
3632 metadata : SafetensorsMetadata ,
3733 model_id : str ,
3834) -> str :
39- """Resolve the effective KV cache dtype string (a `SafetensorsDtypes` value) from the CLI flag and config."""
4035 if kv_cache_dtype in {"fp8_e5m2" , "fp8_e4m3" }:
4136 return kv_cache_dtype .upper ().replace ("FP8" , "F8" ) # type: ignore[union-attr]
4237
@@ -116,18 +111,14 @@ def compute_safetensors_kv_cache_size(
116111 max_model_len : int ,
117112 batch_size : int = 1 ,
118113) -> int :
119- """Compute the KV cache memory requirement in bytes for a Safetensors model.
120-
121- Reference: https://gist.github.com/alvarobartt/1097ca1b07c66fd71470937d599c2072
122- """
123114 hidden_size : int = config ["hidden_size" ]
124115 num_attention_heads : int = config ["num_attention_heads" ]
125116
126117 # NOTE: `num_key_value_heads` defaults to `num_attention_heads` in MHA, and is explicitly
127118 # set to a smaller value in GQA / MQA
128119 num_key_value_heads : int = config .get ("num_key_value_heads" , num_attention_heads )
129120
130- # NOTE: Use head_dim directly if specified in the config; some models (e.g. Qwen3) set
121+ # NOTE: Use `head_dim` directly if specified in the config; some models (e.g. Qwen3) set
131122 # hidden_size and num_attention_heads independently from the actual per-head size,
132123 # making the fallback `hidden_size // num_attention_heads` incorrect for those models
133124 head_dim : int = config .get ("head_dim" , hidden_size // num_attention_heads )
0 commit comments