
Commit 9cc092c

Static attention batch size > 1
Differential Revision: D81245500
Pull Request resolved: #13919
1 parent: b759ae8

2 files changed (+87, -17 lines)
examples/models/llama/static_attention.py

Lines changed: 24 additions & 13 deletions
@@ -242,7 +242,8 @@ def __init__(
         config: ModelArgs,
         input_len: int,
         cache_lens: Union[int, List[int]],
-        dtype=torch.float32,
+        batch_size: int = 1,
+        dtype: torch.dtype = torch.float32,
         style: str = "shift_pointer",
         mask_val: float = float("-inf"),
     ):
@@ -266,15 +267,21 @@ def __init__(
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
                 if cache_lens[layer_id] > 0
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_lens[layer_id], none_throws(config.head_dim), dtype=dtype
+                    batch_size,
+                    cache_lens[layer_id],
+                    none_throws(config.head_dim),
+                    dtype=dtype,
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(none_throws(config.n_kv_heads))
@@ -283,7 +290,7 @@ def __init__(
         else:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
@@ -293,7 +300,7 @@ def __init__(
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1,
+                    batch_size,
                     none_throws(config.n_kv_heads),
                     cache_lens[layer_id],
                     none_throws(config.head_dim),
@@ -323,7 +330,7 @@ def reset(self):
     def prefill(
         self,
         model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
     ) -> torch.Tensor:
         if self.cache_full:
             raise RuntimeError("KV cache is full.")
@@ -336,18 +343,21 @@ def prefill(
             )
         )
 
+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+
         logits = None
         all_logits = None
-        for i in range(0, len(tokens), self.input_len):
-            logits = self._run_once(model, tokens[i : i + self.input_len])[0]
+        for i in range(0, tokens.size(1), self.input_len):
+            logits = self._run_once(model, tokens[:, i : i + self.input_len])[0]
             if self.config.generate_full_logits:
                 if all_logits is None:
                     all_logits = logits
                 else:
                     all_logits = torch.cat([all_logits, logits], dim=1)
 
         if self.config.generate_full_logits:
-            return all_logits[:, : len(tokens), :]
+            return all_logits[:, : tokens.size(1), :]
 
         return logits
 
@@ -510,15 +520,16 @@ def lookahead_decode(  # noqa: C901
     def _run_once(
         self,
         model: Callable[..., Any],
-        tokens: List[int],
+        tokens: Union[List[int], torch.Tensor],
         non_padded_len: Optional[int] = None,
         freqs_cos_override: Optional[torch.Tensor] = None,
         freqs_sin_override: Optional[torch.Tensor] = None,
     ):
-        n_tokens = len(tokens)
+        if isinstance(tokens, list):
+            tokens = torch.tensor([tokens], dtype=torch.int32)
+        n_tokens = tokens.size(1)
         if n_tokens < self.input_len:
-            tokens += [0] * (self.input_len - n_tokens)
-            tokens = torch.tensor([tokens], dtype=torch.int32)  # pyre-ignore[9]
+            tokens = F.pad(tokens, (0, self.input_len - n_tokens))
         if freqs_cos_override is None:
            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
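Net effect of the changes above: the KV caches gain a leading batch dimension, and prefill/_run_once accept either a List[int] or a (batch, seq_len) token tensor. A minimal usage sketch follows; `config` and `model` are hypothetical placeholders (a ModelArgs instance and a static-attention transformer callable, constructed as in the tests below), and the import path is assumed from the repo layout, not part of this commit.

import torch

from executorch.examples.models.llama.static_attention import StaticAttentionIOManager

# `config` and `model` are placeholders; build them as in test_static_attention.py.
batch_size = 2
input_len = 32
cache_len = config.max_seq_len - input_len

# batch_size is the new argument; caches are allocated with a leading batch dim.
mgr = StaticAttentionIOManager(config, input_len, cache_len, batch_size=batch_size)

# prefill now also accepts a (batch, seq_len) int tensor; a List[int] still works
# for the single-sequence case.
tokens = torch.randint(config.vocab_size, (batch_size, input_len), dtype=torch.int32)
logits = mgr.prefill(model, tokens)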

examples/models/llama/tests/test_static_attention.py

Lines changed: 63 additions & 4 deletions
@@ -195,7 +195,7 @@ def test_with_style(style):
         test_with_style("shift_pointer")
         test_with_style("smart_mask")
 
-    def _get_test_transformers(self, config, attention_type="static"):
+    def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
 
         config = copy.copy(config)
@@ -207,6 +207,8 @@ def _get_test_transformers(self, config, attention_type="static"):
         ):
             static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
             static_layer.attention.adopt_hf_rope()
+            if use_conv2d:
+                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config
@@ -220,7 +222,8 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        x = torch.randint(config.vocab_size, (1, config.max_seq_len))
+        batch_size = 3
+        x = torch.randint(config.vocab_size, (batch_size, config.max_seq_len))
         n_chunks = 3
         chunk_len = config.max_seq_len // n_chunks
         cache_len = config.max_seq_len - chunk_len
@@ -235,13 +238,13 @@ def test(style, attention_type):
            expected = mha_transformer(x)
 
            mgr = StaticAttentionIOManager(
-                static_config, chunk_len, cache_len, style=style
+                static_config, chunk_len, cache_len, style=style, batch_size=batch_size
            )
            ys = []
            for i in range(n_chunks):
                y_i = mgr.prefill(
                    static_transformer,
-                    x[0][i * chunk_len : (i + 1) * chunk_len].tolist(),
+                    x[:, i * chunk_len : (i + 1) * chunk_len],
                )
                ys.append(y_i)
 
@@ -300,3 +303,59 @@ def test_lookahead_decode(self):
                ngram_caches=ngram_caches,
            )
         self.assertEqual(lookahead_output[: len(ref_output)], ref_output)
+
+    def test_batched_export_with_backprop(self):
+        config = ModelArgs(
+            dim=64,
+            n_heads=4,
+            n_kv_heads=2,
+            max_seq_len=128,
+            n_layers=4,
+            vocab_size=128,
+            generate_full_logits=True,
+        )
+        _, static_transformer, static_config = self._get_test_transformers(config)
+        batch_size = 4
+        input_len = 32
+        cache_len = static_config.max_seq_len - input_len
+        mgr = StaticAttentionIOManager(
+            static_config, input_len, cache_len, batch_size=batch_size
+        )
+        example_inputs = (
+            torch.zeros(batch_size, input_len),
+            {
+                "masks": mgr.masks,
+                "freqs_cos_override": mgr.freqs_cos[:input_len],
+                "freqs_sin_override": mgr.freqs_sin[:input_len],
+                "in_cache_state": (mgr.k_caches, mgr.v_caches),
+            },
+        )
+        batched_gm = torch.export.export(static_transformer, example_inputs).module()
+
+        # Test backprop
+        for _ in range(10):
+            x = torch.randint(config.vocab_size, (batch_size, input_len))
+            y = mgr.prefill(batched_gm, x)
+            loss = torch.nn.functional.cross_entropy(
+                y, torch.rand(batch_size, input_len, config.vocab_size)
+            )
+            loss.backward()
+            mgr.reset()
+
+        # Test loading state dict into a non batched graph for inference
+        mgr = StaticAttentionIOManager(
+            static_config, input_len, cache_len, batch_size=1
+        )
+        example_inputs = (
+            torch.zeros(1, input_len),
+            {
+                "masks": mgr.masks,
+                "freqs_cos_override": mgr.freqs_cos[:input_len],
+                "freqs_sin_override": mgr.freqs_sin[:input_len],
+                "in_cache_state": (mgr.k_caches, mgr.v_caches),
+            },
+        )
+        non_batched_gm = torch.export.export(
+            static_transformer, example_inputs
+        ).module()
+        non_batched_gm.load_state_dict(batched_gm.state_dict())
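One detail from the _run_once change above worth noting: F.pad(tokens, (0, self.input_len - n_tokens)) right-pads the last dimension of the token tensor with zeros, matching the old list-based `tokens += [0] * (self.input_len - n_tokens)` padding while also covering batched inputs. A standalone sketch of that behavior (plain PyTorch, illustrative values only):

import torch
import torch.nn.functional as F

# A (batch, seq_len) chunk shorter than input_len is right-padded with zeros
# along the last dimension, the same result the old list-based padding produced
# for a single sequence.
tokens = torch.tensor([[5, 7, 9]], dtype=torch.int32)
input_len = 8
padded = F.pad(tokens, (0, input_len - tokens.size(1)))
print(padded)  # tensor([[5, 7, 9, 0, 0, 0, 0, 0]], dtype=torch.int32)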
