Commit ee218a4

Merge pull request #133 from foundation-model-stack/drive_program_script_enhancements_last_n_tokens

Drive program script enhancements last n tokens

2 parents 4b7f51b + b782876

File tree: 10 files changed (+27 additions, -23 deletions)

aiu_fms_testing_utils/testing/validation.py
Lines changed: 4 additions & 4 deletions

@@ -256,7 +256,7 @@ def extract_validation_information(
     post_iteration_hook,
     attn_algorithm=None,
     eos_token_id=None,
-    only_last_token=False,
+    last_n_tokens=0,
     timing="",
     **extra_kwargs,
 ):
@@ -270,10 +270,10 @@ def extract_validation_information(
     attention_specific_kwargs["contiguous_cache"] = True
     attention_specific_kwargs["max_seq_len"] = input_ids.shape[1] + max_new_tokens

-    # Add only_last_token optimization
+    # Add last_n_tokens optimization
     extra_generation_kwargs = {**extra_kwargs}
-    if only_last_token:
-        extra_generation_kwargs["only_last_token"] = only_last_token
+    if last_n_tokens != 0:
+        extra_generation_kwargs["last_n_tokens"] = last_n_tokens
     if attn_algorithm is not None:
         extra_generation_kwargs["attn_algorithm"] = attn_algorithm

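The new keyword generalizes the old boolean: `last_n_tokens=1` corresponds to the former `only_last_token=True`, while `last_n_tokens=0` keeps logits for every position. A minimal sketch of the slicing convention this implies; the helper below is illustrative, not a function in the repo:

```python
import torch

def slice_logits(logits: torch.Tensor, last_n_tokens: int = 0) -> torch.Tensor:
    # logits has shape (batch, seq_len, vocab); 0 keeps all positions,
    # n > 0 keeps only the trailing n positions, so n == 1 reproduces
    # the old only_last_token=True behavior.
    if last_n_tokens == 0:
        return logits
    return logits[:, -last_n_tokens:, :]

logits = torch.randn(2, 128, 100)
assert slice_logits(logits, 0).shape == (2, 128, 100)
assert slice_logits(logits, 1).shape == (2, 1, 100)
assert slice_logits(logits, 64).shape == (2, 64, 100)
```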
aiu_fms_testing_utils/utils/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ def warmup_model(
         **extra_kwargs,
     )

-    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
+    extra_kwargs = {**_extra_kwargs, "last_n_tokens": 64 if "paged" in attn_name else 1}

     with stagger_region(stagger_update_lazyhandle):
         with torch_sendnn.warmup_mode():

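During warmup the paged path now asks for the trailing 64 positions while every other attention keeps only the final token; the 64 presumably matches the 64-token padding granularity used in the examples below, though the commit does not state this. A hypothetical helper expressing the same selection:

```python
def default_last_n_tokens(attn_name: str, block_size: int = 64) -> int:
    # Mirrors the inline expression in warmup_model; the block_size
    # default of 64 is an assumption based on the 64-token padding
    # seen elsewhere in this commit.
    return block_size if "paged" in attn_name else 1

# these attn_name strings are placeholders, not the repo's actual names
assert default_last_n_tokens("paged_fp8") == 64
assert default_last_n_tokens("sdpa") == 1
```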
aiu_fms_testing_utils/utils/paged.py
Lines changed: 11 additions & 7 deletions

@@ -86,6 +86,11 @@ def generate(
     if extra_kwargs is not None:
         kwargs.update(extra_kwargs)

+    # if we didn't specify last_n_tokens and only_last_token is set to True, set last_n_tokens to 1, otherwise use default
+    # we do this since the output shape of only_last_token is different and therefore would change the logic in generate
+    if "last_n_tokens" not in kwargs and kwargs.get("only_last_token", False):
+        kwargs["last_n_tokens"] = 1
+
     is_fp8 = "fp8" in kwargs["attn_name"]
     if isinstance(input_ids, torch.Tensor):
         if len(input_ids.shape) == 1:
@@ -233,7 +238,7 @@ def generate(
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
-    only_last_token = kwargs.get("only_last_token", False)
+    last_n_tokens = kwargs.get("last_n_tokens", 0)

     prompt_length = input_ids.shape[1]

@@ -296,21 +301,20 @@ def generate(
                 t1._scale = current_kv_scales[layer_idx][0][seq_i].reshape(-1)
                 t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)

-            only_last_token = kwargs.get("only_last_token", False)
+            last_n_tokens = kwargs.get("last_n_tokens", 0)
             output, current_kv_cache = model(
                 input_ids_i,
                 slot_mapping=slot_mapping_i,
                 position_ids=position_ids_i,
                 mask=mask_i,
                 past_key_value_states=current_kv_cache,
                 use_cache=kwargs["use_cache"],
-                only_last_token=only_last_token,
+                last_n_tokens=last_n_tokens,
                 attn_name=kwargs["attn_name"],
             )

             # only last token must be handled here to properly stack the tensors
-            if not only_last_token:
-                output = output[:, -1, :]
+            output = output[:, -1, :]

             # TODO: Figure out how to do this cleanly
             if "fp8" in kwargs["attn_name"]:
@@ -341,6 +345,7 @@ def generate(
         kwargs["position_ids"] = kwargs["position_ids"].clone(
             memory_format=torch.contiguous_format
         )
+        kwargs["last_n_tokens"] = 1

         # we no longer have a global pos_i, each sequence has its own pos_i
         slot_mapping = []
@@ -396,8 +401,7 @@ def generate(
         # typically this is done outside of prefill/decode logic, but since this logic already exists as part of the
         # conditional for prefill (since prefill does this within a loop for each batch size 1 prefill), we also provide
         # this same logic as part of the decode conditional
-        if not only_last_token:
-            logits = logits[:, -1, :]
+        logits = logits[:, -1, :]

         output = (logits, past_key_value_states)

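A point worth noting in the hunks above: the `if not only_last_token` guards around `output[:, -1, :]` are dropped because, under the new convention, the returned logits always contain at least the final position, and index -1 selects it no matter how many trailing positions were kept. A small shape check under that assumption:

```python
import torch

batch, vocab = 2, 100
# Assume the model returned logits only for the trailing n positions,
# i.e. shape (batch, last_n_tokens, vocab), per the new convention.
for last_n_tokens in (1, 64):
    trimmed = torch.randn(batch, last_n_tokens, vocab)
    final = trimmed[:, -1, :]  # valid for any n >= 1
    assert final.shape == (batch, vocab)
```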
examples/README.md
Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ prompt = template.format("Provide a list of instructions for preparing chicken s
 input_ids = tokenizer.encode(prompt, return_tensors="pt")
 input_ids, extra_generation_kwargs = pad_input_ids([input_ids.squeeze(0)], min_pad_length=math.ceil(input_ids.size(1)/64) * 64)
 # only_last_token optimization
-extra_generation_kwargs["only_last_token"] = True
+extra_generation_kwargs["last_n_tokens"] = 1
 # Set a desired number
 max_new_tokens = 16
 ```

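In the README snippet, `min_pad_length` rounds the prompt length up to the next multiple of 64 before `last_n_tokens=1` restores the single-token logits optimization. A standalone check of that rounding arithmetic:

```python
import math

def round_up_to_block(length: int, block: int = 64) -> int:
    # Same arithmetic as min_pad_length=math.ceil(input_ids.size(1)/64) * 64
    return math.ceil(length / block) * block

assert round_up_to_block(1) == 64
assert round_up_to_block(64) == 64
assert round_up_to_block(65) == 128
```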
examples/run_granite3.py
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
     [input_ids.squeeze(0)], min_pad_length=math.ceil(input_ids.size(1) / 64) * 64
 )
 # only_last_token optimization
-extra_generation_kwargs["only_last_token"] = True
+extra_generation_kwargs["last_n_tokens"] = 1
 # Set a desired number
 max_new_tokens = 16

scripts/drive_paged_programs.py
Lines changed: 3 additions & 3 deletions

@@ -632,7 +632,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         input_ids,
         max_new_tokens,
         GoldenTokenHook(cpu_validation_info.get_info("tokens")),
-        only_last_token=False,
+        last_n_tokens=64,
         timing=TIMING,
         **extra_kwargs,
     )
@@ -676,7 +676,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         input_ids,
         max_new_tokens,
         None,
-        only_last_token=False,
+        last_n_tokens=64,
         timing=TIMING,
         **extra_kwargs,
     )
@@ -718,7 +718,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         input_ids,
         max_new_tokens,
         None,
-        only_last_token=False,
+        last_n_tokens=64,
         timing=TIMING,
         **extra_kwargs,
     )

scripts/generate_metrics.py
Lines changed: 2 additions & 2 deletions

@@ -259,7 +259,7 @@ def write_csv(metrics, path, metric_name):
         ids.to("cuda"),
         args.max_new_tokens,
         None,
-        only_last_token=True,
+        last_n_tokens=1,
         **{k: v.to("cuda") for k, v in padding_kwargs.items()},
     )
     cuda_static_tokens = cuda_validation_info.get_info("tokens")
@@ -334,7 +334,7 @@ def write_csv(metrics, path, metric_name):
         ids.to("cuda"),
         args.max_new_tokens,
         GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
-        only_last_token=True,
+        last_n_tokens=1,
         **{k: v.to("cuda") for k, v in padding_kwargs.items()},
     )

scripts/inference.py
Lines changed: 1 addition & 1 deletion

@@ -771,7 +771,7 @@ def infer(use_cache, do_sample, warmup):
     global extra_generation_kwargs
     if extra_generation_kwargs is None:
         extra_generation_kwargs = {}
-    extra_generation_kwargs["only_last_token"] = "paged" not in attn_name
+    extra_generation_kwargs["last_n_tokens"] = 64 if "paged" in attn_name else 1

     if not args.no_early_termination and not warmup:
         eos_token_id = tokenizer.eos_token_id

scripts/validation.py
Lines changed: 1 addition & 1 deletion

@@ -710,7 +710,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
         args.max_new_tokens,
         post_iteration_hook,
         eos_token_id=None if args.no_early_termination else tokenizer.eos_token_id,
-        only_last_token=True,
+        last_n_tokens=1,
         timing=args.timing,
         **padding_kwargs,
     )

tests/models/test_decoders.py
Lines changed: 2 additions & 2 deletions

@@ -588,7 +588,7 @@ def test_common_shapes(
         input_ids,
         max_new_tokens,
         None,
-        only_last_token="paged" not in ATTN_NAME,
+        last_n_tokens=64 if "paged" in ATTN_NAME else 1,
         timing=TIMING,
         **extra_kwargs,
     )
@@ -689,7 +689,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
         input_ids,
         max_new_tokens,
         GoldenTokenHook(cpu_static_tokens),
-        only_last_token=ATTN_TYPE != "paged",
+        last_n_tokens=64 if "paged" in ATTN_NAME else 1,
         timing=TIMING,
         **extra_kwargs,
     )
