
Commit 9acd52b

Support initializing StaticAttentionIOManager from any module with StaticAttention inside
Differential Revision: D84854882
Pull Request resolved: pytorch#15206
1 parent 080cd01 commit 9acd52b
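
For context, a minimal usage sketch of what this change enables. This is not part of the commit: the import path and the model_args / model variables are assumptions for illustration, and the input_len / cache_lens values are made up.

# Hypothetical usage sketch (not from the commit).
from executorch.examples.models.llama.static_attention import (
    StaticAttentionIOManager,
)

# Previously the manager had to be constructed from ModelArgs.
mgr = StaticAttentionIOManager(model_args, input_len=32, cache_lens=1024)

# With this change, any nn.Module containing StaticAttention layers also works;
# the manager walks module.modules(), finds the attention layers, and sizes its
# KV caches from each layer's own head_dim / n_kv_heads.
mgr = StaticAttentionIOManager(model, input_len=32, cache_lens=1024)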

1 file changed: +88 −15


examples/models/llama/static_attention.py

Lines changed: 88 additions & 15 deletions
@@ -1,3 +1,4 @@
+import copy
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
@@ -239,7 +240,7 @@ def __str__(self):

     def __init__(
         self,
-        config: ModelArgs,
+        config_or_model: Union[ModelArgs, nn.Module],
         input_len: int,
         cache_lens: Union[int, List[int]],
         batch_size: int = 1,
@@ -248,8 +249,10 @@ def __init__(
         mask_val: float = float("-inf"),
     ):
         if isinstance(cache_lens, int):
-            cache_lens = [cache_lens] * config.n_layers
-        assert len(cache_lens) == config.n_layers
+            cache_lens_dict = defaultdict(lambda x=cache_lens: x)
+            cache_lens = [cache_lens]
+        else:
+            cache_lens_dict = dict(enumerate(cache_lens))

         self._masks = {
             cl: StaticAttentionMask(
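
A note on the cache_lens handling above: when a single int is passed, the manager no longer needs config.n_layers; a defaultdict whose factory captures the int as a default argument returns that length for any layer_id. A standalone sketch of the pattern (the cache length value is an assumed example):

from collections import defaultdict

cache_len = 1024  # assumed example value
cache_lens_dict = defaultdict(lambda x=cache_len: x)

# defaultdict calls the factory with no arguments; the default argument binds
# the value, so any layer_id resolves to the same cache length.
print(cache_lens_dict[0], cache_lens_dict[31])  # 1024 1024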
@@ -258,6 +261,24 @@ def __init__(
             for cl in set(cache_lens)
         }

+        if isinstance(config_or_model, ModelArgs):
+            self._from_config(config_or_model, cache_lens_dict, batch_size, dtype)
+        else:
+            self._from_model(config_or_model, cache_lens_dict, batch_size, dtype)
+
+        self.input_len = input_len
+        self.style = style
+        self.mask_val = mask_val
+        self.pos = 0
+        self.cache_full = False
+
+    def _from_config(
+        self,
+        config: ModelArgs,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
         rope = Rope(config)
         freqs = rope.get_freqs(None, config.max_context_len)
         self.freqs_cos = freqs[0].to(dtype)
@@ -311,13 +332,63 @@ def __init__(
             if cache_lens[layer_id] > 0
         }

-        self.config = config
-        self.input_len = input_len
-        self.cache_lens = cache_lens
-        self.style = style
-        self.mask_val = mask_val
-        self.pos = 0
-        self.cache_full = False
+        self.generate_full_logits = config.generate_full_logits
+
+    def _from_model(
+        self,
+        config: nn.Module,
+        cache_lens: Dict[int, int],
+        batch_size: int,
+        dtype: torch.dtype,
+    ):
+        static_attentions = []
+        for module in config.modules():
+            if isinstance(module, StaticAttention):
+                static_attentions.append(module)
+
+        if not static_attentions:
+            raise ValueError("No StaticAttention modules found in the provided module")
+
+        config = copy.copy(static_attentions[0].rope.config)
+        config.use_hf_rope = static_attentions[0].rope.use_hf_rope
+        rope = Rope(config)
+        freqs = rope.get_freqs(None, config.max_context_len)
+        self.freqs_cos = freqs[0].to(dtype)
+        self.freqs_sin = freqs[1].to(dtype)
+
+        self.k_caches = {}
+        self.v_caches = {}
+        for attn in static_attentions:
+            if attn.split_mha:
+                for head_id in range(attn.n_heads):
+                    cache_key = StaticKVCache.calculate_cache_key(
+                        attn.layer_id, head_id
+                    )
+                    for cache in (self.k_caches, self.v_caches):
+                        assert (
+                            cache_key not in cache
+                        ), "Found StaticAttention modules with duplicated layer_id"
+                        cache[cache_key] = torch.zeros(
+                            batch_size,
+                            cache_lens[attn.layer_id],
+                            attn.head_dim,
+                            dtype=dtype,
+                        )
+            else:
+                cache_key = StaticKVCache.calculate_cache_key(attn.layer_id, 0)
+                for cache in (self.k_caches, self.v_caches):
+                    assert (
+                        cache_key not in cache
+                    ), "Found StaticAttention modules with duplicated layer_id"
+                    cache[cache_key] = torch.zeros(
+                        batch_size,
+                        attn.n_kv_heads,
+                        cache_lens[attn.layer_id],
+                        attn.head_dim,
+                        dtype=dtype,
+                    )
+
+        self.generate_full_logits = True

     @property
     def masks(self):
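
A note on _from_model above: the KV cache shapes differ between the split-MHA and fused paths. A standalone sketch of the resulting tensor shapes, using assumed example sizes rather than values from the commit:

import torch

# Assumed example sizes, not taken from the commit.
batch_size, cache_len, n_kv_heads, head_dim = 1, 1024, 8, 64

# split_mha=True: one (batch, cache_len, head_dim) cache per KV head.
per_head_cache = torch.zeros(batch_size, cache_len, head_dim)

# split_mha=False: a single fused (batch, n_kv_heads, cache_len, head_dim)
# cache per layer.
fused_cache = torch.zeros(batch_size, n_kv_heads, cache_len, head_dim)

print(per_head_cache.shape, fused_cache.shape)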
@@ -352,13 +423,13 @@ def prefill(
         all_logits = None
         for i in range(0, tokens.size(1), self.input_len):
             logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
-            if self.config.generate_full_logits:
+            if self.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
                 else:
                     all_logits = torch.cat([all_logits, logits], dim=1)

-        if self.config.generate_full_logits:
+        if self.generate_full_logits:
             return all_logits[:, : tokens.size(1), :]

         return logits
@@ -637,9 +708,10 @@ def _get_lookahead_position_offsets(


 class _Rope(nn.Module):
-    def __init__(self, use_hf_rope):
+    def __init__(self, config: ModelArgs):
         super().__init__()
-        self.use_hf_rope = use_hf_rope
+        self.config = config
+        self.use_hf_rope = config.use_hf_rope

     def forward(
         self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
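
A note on the _Rope change above: keeping the full ModelArgs (rather than only use_hf_rope) is what _from_model relies on to recover a config from any discovered StaticAttention. A sketch of that lookup, assuming `model` is a module with StaticAttention submodules and that StaticAttention is in scope:

# Sketch only; `model` is an assumed variable.
attn = next(m for m in model.modules() if isinstance(m, StaticAttention))
config = attn.rope.config  # full ModelArgs kept by the new _Rope.__init__
print(config.max_context_len, attn.rope.use_hf_rope)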
@@ -755,7 +827,8 @@ def __init__(
         self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)])

         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-        self.rope = _Rope(rope.params.use_hf_rope)
+        self.rope = _Rope(rope.params)
+        self.layer_id = layer_id

         if self.use_qk_norm:
             self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
