-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathdrive_paged_programs.py
More file actions
941 lines (860 loc) · 35.2 KB
/
drive_paged_programs.py
File metadata and controls
941 lines (860 loc) · 35.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
import argparse
from dataclasses import dataclass
import datetime
import itertools
import json
import os
from pathlib import Path
import random
import time
from itertools import dropwhile
import re
import torch
from fms.models import get_model
from fms.utils.generation import pad_input_ids
from torch import distributed as dist
from torch.fx.experimental import _config as fx_config
from transformers import AutoTokenizer
import numpy as np
from aiu_fms_testing_utils.testing.validation import (
GoldenTokenHook,
LogitsExtractorHook,
capture_level_1_metrics,
extract_validation_information,
find_validation_info_path,
get_validation_info_path,
load_validation_information,
top_k_loss_calculator,
)
from aiu_fms_testing_utils.utils import (
get_pad_size,
sample_rag_factoid_requests,
sample_sharegpt_requests,
stagger_region,
warmup_model,
)
from aiu_fms_testing_utils.utils.aiu_setup import aiu_dist_setup, dprint, local_rank
from aiu_fms_testing_utils.utils.paged import (
ProgramCriteria,
get_programs_prompts,
KVCACHE_NUM_BLOCKS_HINT,
)
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
parser = argparse.ArgumentParser(
description="Script which will drive paged programs for debugging"
)
parser.add_argument(
"--programs",
metavar="N",
type=str,
nargs="*",
default=[],
help="""
The list of programs to run. This would take a list where each element would be one of program_id OR <program_id>:<min_batch>,<min_prompt_length>.
If program_id is specified any prompt that would result in this program would be selected.
If <program_id>:<min_batch>,<min_prompt_length> is specified, then with the given program_id, select a prompt that satisfies min_batch and min_prompt_length (if none exists, a message will be printed to warn the user)
If this list is empty, each program will be run once with any prompt that would result in this program being selected.
""",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=8,
help="set this if you want to change the number of tokens generated per sequence (1 prefill + max_new_tokens-1 decodes). Note: If this value is larger than 64, this may result in switching decode programs mid generation",
)
parser.add_argument(
"--distributed",
action="store_true",
help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)",
)
parser.add_argument(
"--model_variant",
type=str,
default="ibm-ai-platform/micro-g3.3-8b-instruct-1b",
help="The model id or path to use for this test. Note: must be a huggingface format",
)
parser.add_argument(
"--timing",
type=str,
choices=["e2e", "per-token"],
default="",
help="if set, how to time the generation of tokens, e2e or per-token",
)
parser.add_argument(
"--program_criteria_json_path",
type=str,
help="path to json file containing the program criteria list",
)
parser.add_argument(
"--dataset_path",
type=str,
help="path to dataset",
)
parser.add_argument(
"--dataset_type",
type=str,
choices=["rag_factoid", "sharegpt", "custom"],
default="sharegpt",
help="selects the correct dataset type for sampling. Must be one of rag_factoid or sharegpt",
)
parser.add_argument(
"--test_type",
type=str,
choices=["tokens", "metrics"],
default="metrics",
help="set the type of the test that you would like to run. If metrics, will inject tokens and get metrics. If tokens, will not inject tokens and get tokens",
)
parser.add_argument(
"--cross_entropy_threshold",
type=float,
default=2.5,
help="threshold to denote passing/failing a given iteration",
)
parser.add_argument(
"--cross_entropy_threshold_path",
type=str,
default=None,
help="path to a file with all expected cross-entropy loss thresholds per program, per sequence. This should generally point to the artifact generated by argument `generate_metrics_path`",
)
parser.add_argument(
"--per_sequence_failure_rate_threshold",
type=float,
default=0.1,
help="the threshold which denotes whether to pass or fail the test for a given sequence.",
)
parser.add_argument(
"--failure_rate_threshold",
type=float,
default=0.1,
help="the threshold which denotes whether to pass or fail the test. The failure threshold is defined as the number of failing iterations (cross_entropy) over the total iterations. If this value exceeds the failure_rate_threshold, we will fail the test",
)
parser.add_argument(
"--attention_type",
type=str,
default="paged",
choices=["paged", "paged_fp8"],
help="The attention type to use",
)
parser.add_argument(
"--prefill_chunk_size",
type=int,
default=0,
help="if > 0, activate chunked prefill, with chunk_size=this_argument. Only works with paged attention variants.",
)
parser.add_argument(
"--stagger_load",
type=int,
default=0,
help="Limit the number of concurrent processes executing the model loading phase. Set to 0 to allow all processes",
)
parser.add_argument(
"--stagger_update_lazyhandle",
type=int,
default=0,
help="Limit the number of concurrent processes executing the AIU update_lazyhandle phase. Set to 0 to allow all processes",
)
parser.add_argument(
"--dist_timeout",
type=int,
default=0,
help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group",
)
parser.add_argument(
"--skip_validation",
action="store_true",
help="set to true to skip cpu validation",
)
parser.add_argument(
"--validation_info_outputs_dir",
type=str,
default="/home/senuser/models/validation_info",
help="path to directory containing validation info outputs",
)
parser.add_argument(
"--save_validation_info_outputs",
action="store_true",
help="set to true to save cpu validation outputs for later consumption",
)
parser.add_argument(
"--prioritize_large_batch_sizes",
action="store_true",
help="set to true if you would like to prioritize large batch sizes",
)
parser.add_argument(
"--enforce_homogeneous_prompt_programs",
action="store_true",
help="set to true ensure that all prompts hit the same prompt program for a given test",
)
parser.add_argument(
"--generate_metrics_path",
type=str,
default=None,
help="if set, will bypass AIU model processing and generate cross-entropy loss thresholds used for testing, and save the metrics to the given path",
)
args = parser.parse_args()
# interleave the decodes for programs (not 3 separate generates)
max_new_tokens = args.max_new_tokens
model_variant = args.model_variant
DATASET_PATH = args.dataset_path
save_validation_info_outputs = args.save_validation_info_outputs
generate_metrics = args.generate_metrics_path is not None
tokenizer = AutoTokenizer.from_pretrained(model_variant)
custom_shape = None
default_cross_entropy_threshold = float(args.cross_entropy_threshold)
program_threshold_dict = {}
# if the path exists, load it as a json
if args.cross_entropy_threshold_path is not None and os.path.exists(
args.cross_entropy_threshold_path
):
with open(args.cross_entropy_threshold_path, "r") as f:
program_threshold_dict = json.load(f)
if args.dataset_type == "custom":
if local_rank == 0:
dprint(
"Using custom prompts from user, programs parameter will be ignored as it will be determined by user prompt"
)
directory = Path(DATASET_PATH)
if not directory.is_dir():
dprint("when using a custom dataset, you must provide a directory")
exit()
result = []
for fp in directory.iterdir():
if fp.is_file():
try:
content = fp.read_text()
result.append((content, get_pad_size(len(tokenizer.encode(content)))))
except Exception as e:
print(f"Error while reading {fp} for custom dataset: {e}")
exit()
custom_shape = (len(result), max([_[1] for _ in result]))
def __custom_line_sampler(*args, **kwargs):
return_key = kwargs.get("return_key", False)
sample_key = format_kwargs_to_string(**kwargs)
if return_key:
return result, sample_key
return result
sampler = __custom_line_sampler
allow_truncation = False
elif args.dataset_type == "rag_factoid":
sampler = sample_rag_factoid_requests
allow_truncation = False
elif args.dataset_type == "sharegpt":
sampler = sample_sharegpt_requests
allow_truncation = True
else:
raise ValueError("dataset_type must be one of rag_factoid or sharegpt")
if args.skip_validation and args.test_type == "metrics":
dprint("When skipping validation, only test_type will be ignored")
USE_DISTRIBUTED = args.distributed
TIMING = args.timing
is_fp8 = "fp8" in args.attention_type
attention_map = {
"sdpa": "sdpa_causal",
"paged": "spyre_paged_attn",
"math_fp8": "math_fp8",
"paged_fp8": "spyre_paged_attn_fp8",
}
ATTN_NAME = attention_map[args.attention_type]
CPU_DTYPE = "fp8" if is_fp8 else "fp32"
torch.manual_seed(42)
torch.set_grad_enabled(False)
os.environ["COMPILATION_MODE"] = "offline_decoder"
os.environ["DT_PROG_CRITERIA_FILEPATH"] = args.program_criteria_json_path
if (
"VLLM_DT_MAX_CONTEXT_LEN" not in os.environ
or "VLLM_DT_MAX_BATCH_SIZE" not in os.environ
):
if local_rank == 0:
dprint(
"Please specify VLLM_DT_MAX_CONTEXT_LEN and VLLM_DT_MAX_BATCH_SIZE environment variables"
)
exit()
max_batch_size = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
start = time.time()
prompts_and_sizes, sample_key = sampler(
DATASET_PATH,
batch_size,
tokenizer,
32,
seq_length * 2 if allow_truncation else seq_length,
seed,
enforce_sizes=enforce_sizes,
truncation=allow_truncation,
return_key=True,
)
end = time.time()
if local_rank == 0:
dprint(f"extracted prompts in {(end - start):.4f} seconds")
prompt_list = []
for prompt, size in prompts_and_sizes:
encoded = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)
if size > seq_length:
assert allow_truncation
encoded = encoded[:seq_length]
prompt_list.append(encoded)
if not prompt_list:
raise ValueError(
f"No valid prompt sample exists in dataset for input shape (Batch Size={batch_size}, Seq Length={seq_length})"
)
if len(prompt_list) < batch_size:
dprint(
f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt."
)
prompt_list = [prompt_list[0]] * (batch_size - len(prompt_list)) + prompt_list
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
return input_ids, extra_kwargs, sample_key
def __maybe_prepare_fp8_weights(model_in, is_fp8):
if is_fp8:
for name, param in model_in.named_parameters():
if param.dtype == torch.bfloat16:
if param.max() > torch.finfo(torch.float16).max:
dprint(
f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended."
)
param.data = param.data.to(dtype=torch.float16)
def __load_validation_info(
model_variant,
batch_size,
seq_length,
max_new_tokens,
tokenizer,
seed,
attn_type: str,
**kwargs,
):
sample_key = kwargs.get("sample_key", None)
full_path = find_validation_info_path(
args.validation_info_outputs_dir,
model_variant,
batch_size,
seq_length,
max_new_tokens,
seed,
attn_type,
version_allow_decrement=True,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
if full_path is not None:
dprint(f"cpu validation info found for seed={seed} -- loading it")
return load_validation_information(full_path, "logits", batch_size, tokenizer)
else:
return None
model_path_kwargs = {}
if os.path.exists(model_variant):
model_path_kwargs = {"model_path": model_variant}
else:
model_path_kwargs = {"variant": model_variant}
distributed_kwargs = {}
if USE_DISTRIBUTED:
if generate_metrics:
torch.cuda.set_device(local_rank)
if args.dist_timeout > 0:
# Default timeout:
# https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
dist.init_process_group(
timeout=datetime.timedelta(minutes=args.dist_timeout), backend="gloo"
)
dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
else:
dist.init_process_group(backend="gloo")
aiu_dist_setup(dist.get_rank(), dist.get_world_size())
distributed_kwargs["distributed_strategy"] = "tp"
distributed_kwargs["group"] = dist.group.WORLD
save_validation_info_outputs = save_validation_info_outputs and (
dist.get_rank() == 0
)
with stagger_region(args.stagger_load):
model = get_model(
architecture="hf_pretrained",
device_type="cuda" if generate_metrics else "cpu",
data_type=None if is_fp8 else torch.float16,
fused_weights=False,
**model_path_kwargs,
**distributed_kwargs,
)
model.eval()
fx_config.backed_size_oblivious = True
if not generate_metrics:
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
__maybe_prepare_fp8_weights(model, is_fp8)
if not args.skip_validation:
with stagger_region(args.stagger_load):
validation_model = get_model(
architecture="hf_pretrained",
device_type="cpu",
data_type=None if is_fp8 else torch.float32,
fused_weights=False,
**model_path_kwargs,
**distributed_kwargs,
)
validation_model.eval()
# warmup with any input so compiler produces criteria json
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
if is_fp8:
prompt_list = prompt_list * 2
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
and USE_DISTRIBUTED
and dist.get_world_size() == 4
):
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
if not generate_metrics:
warmup_model(
model,
input_ids,
max_new_tokens=max_new_tokens,
compile_dynamic_sendnn=True,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
prefill_chunk_size=args.prefill_chunk_size,
**extra_kwargs,
)
if USE_DISTRIBUTED:
# wait for rank0 to be finished as it is the only one generating the criteria json
# this is needed since otherwise we may run into a race condition
torch.distributed.barrier()
@dataclass
class ProgramInfo:
program_id: str
batch_size_limit: int
batch_size_limit_type: str
prompt_length_limit: int
prompt_length_limit_type: str
def parse_program_limit(limit_str: str) -> tuple[int, str]:
matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)")
# Default limit to min to maintain backwards compat
try:
limit_type = ">="
limit_val = int(limit_str)
except ValueError:
limit_type = None
match = matcher.fullmatch(limit_str)
if match is None:
raise ValueError("Program not well formatted, wrong limit type")
limit_type = match.group(1)
limit_val = int(match.group(2))
return limit_val, limit_type
with open(args.program_criteria_json_path, "r") as f:
program_criteria_json_list = json.load(f)["programs"]
program_criteria_list = []
for i, d in enumerate(program_criteria_json_list):
program_criteria_list.append(
ProgramCriteria(
i,
d["max_batch"],
d["max_tkv"],
d["batch_granularity"],
d["tkv_granularity"],
)
)
programs = []
for program_str in args.programs:
enforce_prompt_split = program_str.split(":")
program_id = enforce_prompt_split[0]
if len(enforce_prompt_split) == 1:
programs.append(
ProgramInfo(program_id, 0, ">=", 0, ">=")
) # this will always satisfy
else:
enforce_batch_size, enforce_prompt_length = (
_ for _ in enforce_prompt_split[1].split(",")
)
# Default limit to min to maintain backwards compat
enforce_batch_size_val, enforce_batch_size_type = parse_program_limit(
enforce_batch_size
)
enforce_prompt_length_val, enforce_prompt_length_type = parse_program_limit(
enforce_prompt_length
)
programs.append(
ProgramInfo(
program_id,
enforce_batch_size_val,
enforce_batch_size_type,
enforce_prompt_length_val,
enforce_prompt_length_type,
)
)
if len(programs) == 0:
programs = [
ProgramInfo(str(p.program_id), 0, ">=", 0, ">=")
for p in program_criteria_list
]
# FIXME: filter condition for this on prompt and batch
program_map = get_programs_prompts(
program_criteria_list,
multiple=64,
max_batch_size=max_batch_size,
max_tkv=max_tkv,
program_cycles=max_new_tokens,
prioritize_large_batch_sizes=args.prioritize_large_batch_sizes,
)
for v in program_map.values():
random.Random(42).shuffle(v)
# select prompts that fit the batch size criteria
valid_prompts = []
if custom_shape:
for program_criteria_seq, valid_prompt_shapes in program_map.items():
for valid_prompt_shape in valid_prompt_shapes:
if valid_prompt_shape == custom_shape:
enforce_sizes = [valid_prompt_shape[1]]
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
)
valid_prompts = [
(
program_criteria_seq[0].program_id,
custom_shape,
input_ids,
extra_kwargs,
sample_key,
)
]
break
if len(valid_prompts) > 0:
break
else:
for program_info in programs:
program_id = program_info.program_id
batch_size_limit = program_info.batch_size_limit
batch_size_limit_type = program_info.batch_size_limit_type
prompt_length_limit = program_info.prompt_length_limit
prompt_length_limit_type = program_info.prompt_length_limit_type
filtered_program_map = program_map
if program_id.isnumeric():
filtered_program_map = {
k: v
for k, v in program_map.items()
if k[0] == program_criteria_list[int(program_id)]
}
used_keys = set()
# for each program, we need to check if we have a shape that satisfies the --programs request
for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
# if ? or numeric => we need to check if we have found at least one valid key to stop
if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0:
break
# if * => we need to see if we have found the first key to see if we should skip
elif program_id == "*" and program_seq_key[0] in used_keys:
continue
for valid_prompt_shape in valid_prompt_shapes:
# make sure the criteria for batch limit and prompt limit is satisfied
# eval is safe here because we have limited what type and limit can be before
batch_check = eval(
f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
)
prompt_check = eval(
f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
)
if batch_check and prompt_check:
# when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
enforce_sizes = [valid_prompt_shape[1]]
if args.enforce_homogeneous_prompt_programs:
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
possible_seq_lengths = [
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64)
]
# favor sequences that are close to the valid prompt length
possible_seq_lengths.reverse()
enforce_sizes = enforce_sizes + list(
itertools.islice(
itertools.cycle(possible_seq_lengths),
valid_prompt_shape[0] - 1,
)
)
try:
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
enforce_sizes=enforce_sizes,
)
valid_prompts.append(
(
program_seq_key[0],
valid_prompt_shape,
input_ids,
extra_kwargs,
sample_key,
)
)
used_keys.add(program_seq_key[0])
break
except ValueError:
dprint(
f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
)
if len(used_keys) == 0 and local_rank == 0:
dprint(
f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
)
# metric calculator based on the cross-entropy and mean diff for each decode step
def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
cross_entropy = torch.nn.CrossEntropyLoss()(
r, t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
)
diff = torch.mean(
torch.abs(
r.softmax(dim=1).to(dtype=torch.float32)
- t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
)
)
return (cross_entropy, diff)
per_sentence_failed_cases = []
aggregate_failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
and USE_DISTRIBUTED
and dist.get_world_size() == 4
):
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
if local_rank == 0:
dprint(f"*** testing program {program_id} ***")
dprint(
f"program id: {program_id}, valid prompt: {valid_prompt}, input shape: {input_ids.shape}"
)
if not args.skip_validation:
# attempt to load the cpu validation info if it is already computed
cpu_validation_info = __load_validation_info(
model_variant,
valid_prompt[0],
valid_prompt[1],
max_new_tokens,
tokenizer,
seed=0,
attn_type=ATTN_NAME,
sample_key=sample_key,
)
# if the cpu validation info is not yet computed, compute it
if cpu_validation_info is None:
cpu_validation_info = extract_validation_information(
validation_model,
input_ids,
max_new_tokens,
LogitsExtractorHook(),
attn_algorithm="math",
**extra_kwargs,
)
# save the cpu validation info for later consumption
if save_validation_info_outputs:
cpu_validation_info.save(
get_validation_info_path(
args.validation_info_outputs_dir,
model_variant,
valid_prompt[0],
valid_prompt[1],
max_new_tokens,
0,
ATTN_NAME,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
)
if args.test_type == "metrics":
# if we are generating metrics, all inputs should be on cuda device
if generate_metrics:
input_ids = input_ids.to("cuda")
extra_kwargs = {
k: v.to("cuda") if isinstance(v, torch.Tensor) else v
for k, v in extra_kwargs.items()
}
aiu_validation_info = extract_validation_information(
model,
input_ids,
max_new_tokens,
GoldenTokenHook(cpu_validation_info.get_info("tokens")),
last_n_tokens=64,
timing=TIMING,
prefill_chunk_size=args.prefill_chunk_size,
**extra_kwargs,
)
# capture all level 1 metrics
level_1_metrics = capture_level_1_metrics(
cpu_validation_info.get_info("logits"),
aiu_validation_info.get_info("logits"),
top_k_loss_calculator(20, __metric_calculator),
)
cpu_tokens = cpu_validation_info.get_info("tokens")
for sentence_idx, token_idx, metrics_value in level_1_metrics:
if local_rank == 0:
aiu_token = torch.argmax(
aiu_validation_info.get_info("logits")[sentence_idx][token_idx],
dim=-1,
)
cpu_token = cpu_tokens[sentence_idx][valid_prompt[1] + token_idx]
aiu_str = tokenizer.decode(aiu_token).replace(
"\n", "<NEWLINE>"
) # remove newlines for readability
cpu_str = tokenizer.decode(cpu_token).replace(
"\n", "<NEWLINE>"
) # remove newlines for readability
dprint(
f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"'
)
# if generating metrics, get the 99th percentile ce threshold per sentence
# otherwise test the thresholds
if generate_metrics:
sentence_ce_dict = {}
for sentence_idx, token_idx, metrics_value in level_1_metrics:
sentence_ce_dict.setdefault(sentence_idx, [])
sentence_ce_dict[sentence_idx].append(metrics_value[0])
sentence_ce_threshold = {
k: np.percentile(v, 99) for k, v in sentence_ce_dict.items()
}
if local_rank == 0:
dprint(
f"Program {str(program_id.program_id)} produced the following thresholds:\n{sentence_ce_threshold}"
)
program_threshold_dict[(str(program_id.program_id), sample_key)] = (
sentence_ce_threshold
)
else:
sentence_failures_dict = {}
for sentence_idx, token_idx, metrics_value in level_1_metrics:
program_threshold_key = f"{str(program_id.program_id)},{sample_key}"
if (
len(program_threshold_dict) != 0
and program_threshold_key not in program_threshold_dict
and local_rank == 0
):
dprint(
f"could not find the following key {program_threshold_key}, defaulting to {default_cross_entropy_threshold}"
)
ce_threshold = program_threshold_dict.get(
program_threshold_key,
{str(sentence_idx): default_cross_entropy_threshold},
)[str(sentence_idx)]
sentence_failures_dict.setdefault(sentence_idx, 0)
if metrics_value[0].item() >= ce_threshold:
sentence_failures_dict[sentence_idx] += 1
for sentence_idx, failure_count in sentence_failures_dict.items():
per_sentence_failure_rate = failure_count / max_new_tokens
if (
per_sentence_failure_rate
>= args.per_sequence_failure_rate_threshold
):
per_sentence_failed_cases.append(
(
program_id,
valid_prompt,
sentence_idx,
per_sentence_failure_rate,
)
)
aggregate_failure_rate = sum(sentence_failures_dict.values()) / (
max_new_tokens * len(sentence_failures_dict)
)
if aggregate_failure_rate >= args.failure_rate_threshold:
aggregate_failed_cases.append(
(
program_id,
valid_prompt,
aggregate_failure_rate,
)
)
elif args.test_type == "tokens":
aiu_validation_info = extract_validation_information(
model,
input_ids,
max_new_tokens,
None,
last_n_tokens=64,
timing=TIMING,
prefill_chunk_size=args.prefill_chunk_size,
**extra_kwargs,
)
if local_rank == 0:
for sentence_idx, (reference_sentence, test_sentence) in enumerate(
zip(
cpu_validation_info.get_info("tokens"),
aiu_validation_info.get_info("tokens"),
)
):
tokens_prompt = [
t.item() for t in reference_sentence[:-max_new_tokens]
]
cpu_tokens_generated = [
t.item() for t in reference_sentence[-max_new_tokens:]
]
aiu_tokens_generated = [
t.item() for t in test_sentence[-max_new_tokens:]
]
tokens_prompt_without_pad = list(
dropwhile(lambda x: x == tokenizer.pad_token_id, tokens_prompt)
)
prompt_length = len(
[token_id for token_id in tokens_prompt_without_pad]
)
dprint(f"Prompt Length: {prompt_length}")
dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:")
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt_without_pad)}")
dprint(f"CPU tokens:\n{cpu_tokens_generated}")
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}")
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
else:
raise ValueError("test type must be one of metrics or tokens")
else:
aiu_validation_info = extract_validation_information(
model,
input_ids,
max_new_tokens,
None,
last_n_tokens=64,
timing=TIMING,
prefill_chunk_size=args.prefill_chunk_size,
**extra_kwargs,
)
if local_rank == 0:
for sentence_idx, test_sentence in enumerate(
aiu_validation_info.get_info("tokens")
):
tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]]
aiu_tokens_generated = [
t.item() for t in test_sentence[-max_new_tokens:]
]
dprint(f"For Program {program_id} in sentence {sentence_idx + 1}:")
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
if generate_metrics and local_rank == 0:
with open(args.generate_metrics_path, "w") as f:
json_dict = {}
for program_seq, sentence_ce_threshold_dict in program_threshold_dict.items():
program_seq_key = ",".join(program_seq)
json_dict[program_seq_key] = {}
for sentence_i, ce_threshold in sentence_ce_threshold_dict.items():
json_dict[program_seq_key][sentence_i] = float(ce_threshold)
json.dump(json_dict, f, indent=4)
if not args.skip_validation and local_rank == 0:
if len(aggregate_failed_cases) != 0:
dprint("the test failed with the following aggregate cases:")
for failed_case in aggregate_failed_cases:
dprint(
f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
)
if len(per_sentence_failed_cases) != 0:
dprint("the test failed with the following per sentence cases:")
for failed_case in per_sentence_failed_cases:
dprint(
f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Sentence Index: {failed_case[2]}, Failure Rate: {failed_case[3]}"
)
if len(aggregate_failed_cases) == 0 and len(per_sentence_failed_cases) == 0:
dprint("all tests passed")