You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Default limit to min to maintain backwards compat
@@ -483,7 +485,6 @@ def get_valid_prompts(
483
485
allow_truncation: bool,
484
486
custom_shape: Optional[Tuple[int, int]],
485
487
pad_multiple: int,
486
-
prefill_chunk_size: int,
487
488
):
488
489
# select prompts that fit the batch size criteria
489
490
valid_prompts= []
@@ -500,7 +501,6 @@ def get_valid_prompts(
500
501
dataset_path=dataset_path,
501
502
allow_truncation=allow_truncation,
502
503
enforce_sizes=enforce_sizes,
503
-
pad_multiple=pad_multiple,
504
504
)
505
505
valid_prompts= [
506
506
(
@@ -550,28 +550,14 @@ def get_valid_prompts(
550
550
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
551
551
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
552
552
enforce_sizes= [valid_prompt_shape[1]]
553
-
if (
554
-
enforce_homogeneous_prompt_programs
555
-
orprefill_chunk_size>0
556
-
):
557
-
# if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
558
-
tkv_cutoff= (
559
-
1<< (valid_prompt_shape[1].bit_length() -1)
560
-
ifenforce_homogeneous_prompt_programs
561
-
elsepad_multiple
562
-
)
553
+
ifenforce_homogeneous_prompt_programs:
554
+
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
# favor sequences that are close to the valid prompt length
570
560
possible_seq_lengths.reverse()
571
-
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
572
-
possible_seq_lengths=possible_seq_lengths+ [
573
-
valid_prompt_shape[1]
574
-
]
575
561
enforce_sizes=enforce_sizes+list(
576
562
itertools.islice(
577
563
itertools.cycle(possible_seq_lengths),
@@ -587,7 +573,6 @@ def get_valid_prompts(
587
573
dataset_path=dataset_path,
588
574
allow_truncation=allow_truncation,
589
575
enforce_sizes=enforce_sizes,
590
-
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
591
576
)
592
577
valid_prompts.append(
593
578
(
@@ -783,6 +768,7 @@ def main():
783
768
784
769
is_fp8="fp8"inargs.attention_type
785
770
CPU_DTYPE="fp8"ifis_fp8else"fp32"
771
+
PAD_MULTIPLE=64
786
772
787
773
torch.manual_seed(42)
788
774
torch.set_grad_enabled(False)
@@ -838,13 +824,7 @@ def main():
838
824
# warmup with any input so compiler produces criteria json
839
825
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
0 commit comments