Skip to content

Commit c2a25ee

Browse files
committed
Bring refactor up to date with current script
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 3debfea commit c2a25ee

File tree

1 file changed

+13
-34
lines changed

1 file changed

+13
-34
lines changed

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 13 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,6 @@ def __prepare_inputs(
213213
allow_truncation: bool,
214214
enforce_sizes: List[int] = [],
215215
seed: int = 0,
216-
pad_multiple: int = 64,
217216
):
218217
start = time.time()
219218
prompts_and_sizes, sample_key = sampler(
@@ -226,7 +225,6 @@ def __prepare_inputs(
226225
enforce_sizes=enforce_sizes,
227226
truncation=allow_truncation,
228227
return_key=True,
229-
pad_multiple=pad_multiple,
230228
)
231229
end = time.time()
232230
if local_rank == 0:
@@ -239,6 +237,10 @@ def __prepare_inputs(
239237
encoded = encoded[:seq_length]
240238
prompt_list.append(encoded)
241239

240+
if not prompt_list:
241+
raise ValueError(
242+
f"No valid prompt sample exists in dataset for input shape (Batch Size={batch_size}, Seq Length={seq_length})"
243+
)
242244
if len(prompt_list) < batch_size:
243245
dprint(
244246
f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt."
@@ -292,7 +294,7 @@ def __load_validation_info(
292294
return None
293295

294296

295-
def parse_program_limit(limit_str: str) -> tuple[int, str]:
297+
def parse_program_limit(limit_str: str) -> tuple[int, str | None]:
296298
matcher = re.compile(r"^(<|>|<=|>=|==)(\d+)")
297299

298300
# Default limit to min to maintain backwards compat
@@ -483,7 +485,6 @@ def get_valid_prompts(
483485
allow_truncation: bool,
484486
custom_shape: Optional[Tuple[int, int]],
485487
pad_multiple: int,
486-
prefill_chunk_size: int,
487488
):
488489
# select prompts that fit the batch size criteria
489490
valid_prompts = []
@@ -500,7 +501,6 @@ def get_valid_prompts(
500501
dataset_path=dataset_path,
501502
allow_truncation=allow_truncation,
502503
enforce_sizes=enforce_sizes,
503-
pad_multiple=pad_multiple,
504504
)
505505
valid_prompts = [
506506
(
@@ -550,28 +550,14 @@ def get_valid_prompts(
550550
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
551551
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
552552
enforce_sizes = [valid_prompt_shape[1]]
553-
if (
554-
enforce_homogeneous_prompt_programs
555-
or prefill_chunk_size > 0
556-
):
557-
# if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
558-
tkv_cutoff = (
559-
1 << (valid_prompt_shape[1].bit_length() - 1)
560-
if enforce_homogeneous_prompt_programs
561-
else pad_multiple
562-
)
553+
if enforce_homogeneous_prompt_programs:
554+
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
555+
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
563556
possible_seq_lengths = [
564-
_
565-
for _ in range(
566-
tkv_cutoff, valid_prompt_shape[1], pad_multiple
567-
)
557+
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], pad_multiple)
568558
]
569559
# favor sequences that are close to the valid prompt length
570560
possible_seq_lengths.reverse()
571-
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
572-
possible_seq_lengths = possible_seq_lengths + [
573-
valid_prompt_shape[1]
574-
]
575561
enforce_sizes = enforce_sizes + list(
576562
itertools.islice(
577563
itertools.cycle(possible_seq_lengths),
@@ -587,7 +573,6 @@ def get_valid_prompts(
587573
dataset_path=dataset_path,
588574
allow_truncation=allow_truncation,
589575
enforce_sizes=enforce_sizes,
590-
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
591576
)
592577
valid_prompts.append(
593578
(
@@ -783,6 +768,7 @@ def main():
783768

784769
is_fp8 = "fp8" in args.attention_type
785770
CPU_DTYPE = "fp8" if is_fp8 else "fp32"
771+
PAD_MULTIPLE = 64
786772

787773
torch.manual_seed(42)
788774
torch.set_grad_enabled(False)
@@ -838,13 +824,7 @@ def main():
838824
# warmup with any input so compiler produces criteria json
839825
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
840826
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
841-
pad_multiple = 64
842-
if args.prefill_chunk_size > 0:
843-
assert args.prefill_chunk_size % 64 == 0, (
844-
"Chunk size must be a multiple of the page size"
845-
)
846-
pad_multiple = args.prefill_chunk_size
847-
prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)]
827+
prompt_list = [torch.arange(0, PAD_MULTIPLE, dtype=torch.int64)]
848828
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
849829
if is_fp8:
850830
prompt_list = prompt_list * 2
@@ -895,7 +875,7 @@ def main():
895875
# FIXME: filter condition for this on prompt and batch
896876
program_map = get_programs_prompts(
897877
program_criteria_list=program_criteria_list,
898-
multiple=pad_multiple,
878+
multiple=PAD_MULTIPLE,
899879
max_batch_size=max_batch_size,
900880
max_tkv=max_tkv,
901881
program_cycles=args.max_new_tokens,
@@ -920,8 +900,7 @@ def main():
920900
sampler=sampler,
921901
allow_truncation=allow_truncation,
922902
custom_shape=custom_shape,
923-
pad_multiple=pad_multiple,
924-
prefill_chunk_size=args.prefill_chunk_size,
903+
pad_multiple=PAD_MULTIPLE,
925904
)
926905

927906
## RUN VALIDATION AND TESTS ##

0 commit comments

Comments (0)