From 36d4b58e492abda076425dce906dafce497c6924 Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Thu, 16 Oct 2025 13:56:14 -0700
Subject: [PATCH] Support initializing StaticAttentionIOManager from any module
 with StaticAttention inside

Summary:
Sometimes the `ModelArgs` object is not available, for example:
https://www.internalfb.com/code/fbsource/[3243b3d06108fff9ab5a9062905844454e600e1d]/fbcode/assistant/multimodal/llm_mm_aligner/lib/models/minymal/model_export/export_vlm.py?lines=161-164

This diff provides an alternative way to initialize the IO manager from any
module that contains `StaticAttention`. The assumption is that the RoPE
settings of those attention modules are all the same.

Differential Revision: D84854882
---
 examples/models/llama/static_attention.py | 103 ++++++++++++++++++----
 1 file changed, 88 insertions(+), 15 deletions(-)

diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 1880a09f5c6..4d5b9c1da57 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
@@ -239,7 +240,7 @@ def __str__(self):
 
     def __init__(
         self,
-        config: ModelArgs,
+        config_or_model: Union[ModelArgs, nn.Module],
         input_len: int,
         cache_lens: Union[int, List[int]],
         batch_size: int = 1,
@@ -248,8 +249,10 @@ def __init__(
         mask_val: float = float("-inf"),
     ):
         if isinstance(cache_lens, int):
-            cache_lens = [cache_lens] * config.n_layers
-        assert len(cache_lens) == config.n_layers
+            cache_lens_dict = defaultdict(lambda x=cache_lens: x)
+            cache_lens = [cache_lens]
+        else:
+            cache_lens_dict = dict(enumerate(cache_lens))
 
         self._masks = {
             cl: StaticAttentionMask(
@@ -258,6 +261,24 @@ def __init__(
             for cl in set(cache_lens)
         }
 
+        if isinstance(config_or_model, ModelArgs):
+            self._from_config(config_or_model, cache_lens_dict, batch_size, dtype)
+        else:
+            self._from_model(config_or_model, cache_lens_dict, batch_size, dtype)
+
+        self.input_len = input_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def _from_config(
+        self,
+        config: ModelArgs,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
         rope = Rope(config)
         freqs = rope.get_freqs(None, config.max_context_len)
         self.freqs_cos = freqs[0].to(dtype)
@@ -311,13 +332,63 @@ def __init__(
             if cache_lens[layer_id] > 0
         }
 
-        self.config = config
-        self.input_len = input_len
-        self.cache_lens = cache_lens
-        self.style = style
-        self.mask_val = mask_val
-        self.pos = 0
-        self.cache_full = False
+        self.generate_full_logits = config.generate_full_logits
+
+    def _from_model(
+        self,
+        config: nn.Module,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
+        static_attentions = []
+        for module in config.modules():
+            if isinstance(module, StaticAttention):
+                static_attentions.append(module)
+
+        if not static_attentions:
+            raise ValueError("No StaticAttention modules found in the provided module")
+
+        config = copy.copy(static_attentions[0].rope.config)
+        config.use_hf_rope = static_attentions[0].rope.use_hf_rope
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_context_len)
+        self.freqs_cos = freqs[0].to(dtype)
+        self.freqs_sin = freqs[1].to(dtype)
+
+        self.k_caches = {}
+        self.v_caches = {}
+        for attn in static_attentions:
+            if attn.split_mha:
+                for head_id in range(attn.n_heads):
+                    cache_key = StaticKVCache.calculate_cache_key(
+                        attn.layer_id, head_id
+                    )
+                    for cache in (self.k_caches, self.v_caches):
+                        assert (
+                            cache_key not in cache
+                        ), "Found StaticAttention modules with duplicated layer_id"
+                        cache[cache_key] = torch.zeros(
+                            batch_size,
+                            cache_lens[attn.layer_id],
+                            attn.head_dim,
+                            dtype=dtype,
+                        )
+            else:
+                cache_key = StaticKVCache.calculate_cache_key(attn.layer_id, 0)
+                for cache in (self.k_caches, self.v_caches):
+                    assert (
+                        cache_key not in cache
+                    ), "Found StaticAttention modules with duplicated layer_id"
+                    cache[cache_key] = torch.zeros(
+                        batch_size,
+                        attn.n_kv_heads,
+                        cache_lens[attn.layer_id],
+                        attn.head_dim,
+                        dtype=dtype,
+                    )
+
+        self.generate_full_logits = True
 
     @property
     def masks(self):
@@ -352,13 +423,13 @@ def prefill(
         all_logits = None
         for i in range(0, tokens.size(1), self.input_len):
             logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
-            if self.config.generate_full_logits:
+            if self.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
                 else:
                     all_logits = torch.cat([all_logits, logits], dim=1)
 
-        if self.config.generate_full_logits:
+        if self.generate_full_logits:
             return all_logits[:, : tokens.size(1), :]
         return logits
 
@@ -637,9 +708,10 @@ def _get_lookahead_position_offsets(
 
 
 class _Rope(nn.Module):
-    def __init__(self, use_hf_rope):
+    def __init__(self, config: ModelArgs):
         super().__init__()
-        self.use_hf_rope = use_hf_rope
+        self.config = config
+        self.use_hf_rope = config.use_hf_rope
 
     def forward(
         self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
@@ -755,7 +827,8 @@ def __init__(
             self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])
 
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-        self.rope = _Rope(rope.params.use_hf_rope)
+        self.rope = _Rope(rope.params)
+        self.layer_id = layer_id
 
         if self.use_qk_norm:
             self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
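
Illustrative usage of the new initialization path (a minimal sketch, not part of the patch: the helper name, the `input_len`/`cache_lens` values, and the import path are assumptions; `StaticAttentionIOManager` and its `prefill` method are the APIs this diff touches):

```python
import torch
from torch import nn

# Assumption: import path mirrors the file's location in the ExecuTorch tree.
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager


def prefill_with_module(model: nn.Module, prompt: torch.Tensor) -> torch.Tensor:
    """Prefill via an IO manager built from any module that contains
    StaticAttention layers, instead of from a ModelArgs object."""
    mgr = StaticAttentionIOManager(
        model,           # previously this argument had to be a ModelArgs instance
        input_len=32,    # hypothetical chunk length for static-shape prefill
        cache_lens=512,  # hypothetical cache length, applied to every layer
    )
    return mgr.prefill(model, prompt)
```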