Skip to content

Commit 4ad711c

Browse files
committed
[dpp] Added a return_key option to the samplers in [utils] so they can also return the sample_key; it defaults to False and is enabled for dpp
Signed-off-by: dliu-ibm <[email protected]>
1 parent 67be471 commit 4ad711c

File tree

3 files changed: 51 additions (+51) and 10 deletions (-10)

aiu_fms_testing_utils/testing/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,9 @@ def get_validation_info_path(
440440
if aftu_version is None:
441441
aftu_version = version_tuple
442442

443-
enforce_sizes = kwargs.get("enforce_sizes", None)
443+
sample_key = kwargs.get("sample_key", None)
444444

445-
validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]])), enforce_sizes}.{device_type}_validation_info.{seed}.out"
445+
validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]])), sample_key}.{device_type}_validation_info.{seed}.out"
446446
full_path = os.path.join(validation_info_dir, validation_file_name)
447447
return full_path
448448

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
1313
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
14+
from aiu_fms_testing_utils.testing.validation import format_kwargs_to_string
1415

1516
from fms.utils.generation import pad_input_ids
1617
import torch
@@ -479,6 +480,7 @@ def sample_rag_factoid_requests(
479480
enforce_sizes: List[int] = [],
480481
truncation: bool = False,
481482
pad_multiple: int = 64,
483+
return_key: bool = False,
482484
) -> List[Tuple[str, int]]:
483485
if not os.path.exists(dataset_path):
484486
print("error dataset does not exist")
@@ -489,7 +491,7 @@ def sample_rag_factoid_requests(
489491
for line in f:
490492
dataset.append(line)
491493

492-
return __sample_requests(
494+
sample_request = __sample_requests(
493495
dataset,
494496
num_requests,
495497
tokenizer,
@@ -503,6 +505,24 @@ def sample_rag_factoid_requests(
503505
_cached_dataset_key=dataset_path,
504506
)
505507

508+
sample_key: str = format_kwargs_to_string(
509+
dataset="rag_factoid",
510+
num_requests=num_requests,
511+
tokenizer=tokenizer.name_or_path.replace("/", "--"),
512+
prompt_length_min=prompt_length_min,
513+
prompt_length_max=prompt_length_max,
514+
seed=seed,
515+
enforce_heterogeneous=enforce_heterogeneous,
516+
enforce_sizes=enforce_sizes,
517+
truncate=truncation,
518+
pad_multiple=pad_multiple,
519+
)
520+
521+
if return_key:
522+
return sample_request, sample_key
523+
else:
524+
return sample_request
525+
506526

507527
def sample_sharegpt_requests(
508528
dataset_path: str,
@@ -515,6 +535,7 @@ def sample_sharegpt_requests(
515535
enforce_sizes: List[int] | None = None,
516536
truncation: bool = False,
517537
pad_multiple: int = 64,
538+
return_key: bool = False,
518539
) -> List[Tuple[str, int]]:
519540
if not os.path.exists(dataset_path):
520541
print("downloading share-gpt dataset as it does not exist")
@@ -540,7 +561,7 @@ def sample_sharegpt_requests(
540561
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
541562
dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]
542563

543-
return __sample_requests(
564+
sample_request = __sample_requests(
544565
dataset,
545566
num_requests,
546567
tokenizer,
@@ -554,6 +575,24 @@ def sample_sharegpt_requests(
554575
_cached_dataset_key=dataset_path,
555576
)
556577

578+
sample_key: str = format_kwargs_to_string(
579+
dataset="sharegpt",
580+
num_requests=num_requests,
581+
tokenizer=tokenizer.name_or_path.replace("/", "--"),
582+
prompt_length_min=prompt_length_min,
583+
prompt_length_max=prompt_length_max,
584+
seed=seed,
585+
enforce_heterogeneous=enforce_heterogeneous,
586+
enforce_sizes=enforce_sizes,
587+
truncate=truncation,
588+
pad_multiple=pad_multiple,
589+
)
590+
591+
if return_key:
592+
return sample_request, sample_key
593+
else:
594+
return sample_request
595+
557596

558597
def sample_squad_v2_qa_requests(
559598
dataset_path: str,

scripts/drive_paged_programs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __custom_line_sampler(*args, **kwargs):
245245

246246
def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
247247
start = time.time()
248-
prompts_and_sizes = sampler(
248+
prompts_and_sizes, sample_key = sampler(
249249
DATASET_PATH,
250250
batch_size,
251251
tokenizer,
@@ -254,6 +254,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0
254254
seed,
255255
enforce_sizes=enforce_sizes,
256256
truncation=allow_truncation,
257+
return_key=True,
257258
)
258259
end = time.time()
259260
if local_rank == 0:
@@ -274,7 +275,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0
274275

275276
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
276277
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
277-
return input_ids, extra_kwargs
278+
return input_ids, extra_kwargs, sample_key
278279

279280

280281
def __maybe_prepare_fp8_weights(model_in, is_fp8):
@@ -367,13 +368,14 @@ def __load_validation_info(
367368

368369
# warmup with any input so compiler produces criteria json
369370
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
370-
# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer)
371+
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
371372
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
372373
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
373374
if is_fp8:
374375
prompt_list = prompt_list * 2
375376
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
376377
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
378+
377379
extra_kwargs["attn_name"] = ATTN_NAME
378380
if (
379381
"granite-3.3-8b-instruct" in model_variant
@@ -572,8 +574,8 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
572574
itertools.islice(itertools.cycle(possible_seq_lengths), valid_prompt[0] - 1)
573575
)
574576

575-
input_ids, extra_kwargs = __prepare_inputs(
576-
valid_prompt[0], valid_prompt[1], tokenizer, enforce_sizes=enforce_sizes
577+
input_ids, extra_kwargs, sample_key = __prepare_inputs(
578+
valid_prompt[0], valid_prompt[1], tokenizer, enforce_sizes=[valid_prompt[1]]
577579
)
578580
extra_kwargs["attn_name"] = ATTN_NAME
579581
if (
@@ -622,7 +624,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
622624
0,
623625
ATTN_NAME,
624626
dtype=CPU_DTYPE,
625-
enforce_sizes=[valid_prompt[1]],
627+
sample_key=sample_key,
626628
)
627629
)
628630

0 commit comments

Comments
 (0)