diff --git a/aiu_fms_testing_utils/testing/utils.py b/aiu_fms_testing_utils/testing/utils.py
new file mode 100644
index 00000000..72fd30b2
--- /dev/null
+++ b/aiu_fms_testing_utils/testing/utils.py
@@ -0,0 +1,22 @@
+from collections.abc import Iterable
+
+
+def format_kwargs_to_string(**kwargs):
+    """
+    Turns kwargs into a str where `_` in variable names becomes `-`, key-value pairs are separated by `_`, and iterable values are joined by `,`
+    """
+    formatted_pairs = []
+    for key, value in sorted(kwargs.items()):
+        formatted_value = None
+        if isinstance(value, str):
+            formatted_value = value
+        elif isinstance(value, Iterable):
+            formatted_value = ",".join(map(str, value))
+        elif value:
+            formatted_value = str(value)
+        # only append if formatted_value exists
+        if formatted_value:
+            # Keep previous convention of variable names with `-` instead of `_`
+            formatted_pairs.append(f"{key.replace('_', '-')}-{formatted_value}")
+
+    return "_".join(formatted_pairs)
diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index 0c655ff5..ad1b5906 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -5,6 +5,9 @@
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 from aiu_fms_testing_utils._version import version_tuple
 import os
+from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
+
+import hashlib
 
 
 class LogitsExtractorHook(
@@ -132,6 +135,7 @@ def get_default_validation_prefix(
     dtype: str,
     attn_type: str,
     aftu_version: str,
+    **kwargs,
 ):
     """
     Args:
@@ -144,9 +148,17 @@
         aftu_version (str): introduced in v0.3.0 to track changed in log
 
     Returns:
-        str: A prefix that will be prepended to the file name
+        str: A hashed prefix that will be prepended to the file name
     """
-    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
+    kwargs_str = format_kwargs_to_string(**kwargs)
+
+    if kwargs_str == "":
+        filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
+    else:
+        filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}"
+    hash_object = hashlib.sha256(filename.encode("utf-8"))
+    hex_digest = hash_object.hexdigest()
+    return f"{hex_digest}_{aftu_version}"
 
 
 def load_validation_information(
@@ -416,11 +428,14 @@
     aftu_version: Optional[Tuple[int, int, int]] = None,
     device_type: str = "cpu",
     dtype: str = "fp16",
+    **kwargs,
 ):
     if aftu_version is None:
         aftu_version = version_tuple
 
-    validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
+    sample_key = kwargs.get("sample_key", None)
+
+    validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
     full_path = os.path.join(validation_info_dir, validation_file_name)
 
     return full_path
@@ -452,10 +467,12 @@
     version_allow_decrement: bool = False,
     device_type: str = "cpu",
     dtype: str = "fp16",
+    **kwargs,
 ):
     """
     Find the validation info path if it exists, otherwise return None
     """
+    sample_key = kwargs.get("sample_key", None)
 
     if aftu_version is None:
         loc_version_tuple = version_tuple[:3]
@@ -476,6 +493,7 @@
             loc_version_tuple,
             device_type,
             dtype,
+            sample_key=sample_key,
         )
         # if the path is found, we are done searching and can return
         if os.path.exists(full_path):
diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 65a0f9ab..6615c5c9 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -11,6 +11,7 @@
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
+from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
 from fms.utils.generation import pad_input_ids
 import torch
 
@@ -482,6 +483,7 @@
     enforce_sizes: List[int] = [],
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
         print("error dataset does not exist")
@@ -492,7 +494,7 @@
         for line in f:
             dataset.append(line)
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         dataset,
         num_requests,
         tokenizer,
@@ -506,6 +508,24 @@
         _cached_dataset_key=dataset_path,
     )
+
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="rag_factoid",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+
+        return sample_request, sample_key
+    else:
+        return sample_request
 
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -518,6 +538,7 @@
     enforce_sizes: List[int] | None = None,
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
         print("downloading share-gpt dataset as it does not exist")
@@ -543,7 +564,7 @@
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         dataset,
         num_requests,
         tokenizer,
@@ -557,6 +578,23 @@
         _cached_dataset_key=dataset_path,
     )
+
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="sharegpt",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+        return sample_request, sample_key
+    else:
+        return sample_request
 
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
@@ -569,6 +607,7 @@
     enforce_sizes: List[int] | None = None,
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
 
@@ -582,7 +621,7 @@
 
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         ds,
         num_requests,
         tokenizer,
@@ -595,6 +634,23 @@
         pad_multiple,
     )
+
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="squad_v2",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+        return sample_request, sample_key
+    else:
+        return sample_request
 
 
 def prepare_inputs(
     batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"
diff --git a/scripts/drive_paged_programs.py b/scripts/drive_paged_programs.py
index ea51bad8..033a8efe 100644
--- a/scripts/drive_paged_programs.py
+++ b/scripts/drive_paged_programs.py
@@ -40,6 +40,7 @@
     get_programs_prompts,
     KVCACHE_NUM_BLOCKS_HINT,
 )
+from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
 
 parser = argparse.ArgumentParser(
     description="Script which will drive paged programs for debugging"
@@ -195,6 +196,10 @@
     custom_shape = (len(result), max([_[1] for _ in result]))
 
     def __custom_line_sampler(*args, **kwargs):
+        return_key = kwargs.get("return_key", False)
+        sample_key = format_kwargs_to_string(**kwargs)
+        if return_key:
+            return result, sample_key
         return result
 
     sampler = __custom_line_sampler
@@ -245,7 +250,7 @@
 
 def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
     start = time.time()
-    prompts_and_sizes = sampler(
+    prompts_and_sizes, sample_key = sampler(
         DATASET_PATH,
         batch_size,
         tokenizer,
@@ -254,6 +259,7 @@
         seed,
         enforce_sizes=enforce_sizes,
         truncation=allow_truncation,
+        return_key=True,
     )
     end = time.time()
     if local_rank == 0:
@@ -274,7 +280,7 @@
     input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
 
-    return input_ids, extra_kwargs
+    return input_ids, extra_kwargs, sample_key
 
 
 def __maybe_prepare_fp8_weights(model_in, is_fp8):
@@ -296,7 +302,9 @@
     tokenizer,
     seed,
     attn_type: str,
+    **kwargs,
 ):
+    sample_key = kwargs.get("sample_key", None)
     full_path = find_validation_info_path(
         args.validation_info_outputs_dir,
         model_variant,
@@ -307,6 +315,7 @@
         attn_type,
         version_allow_decrement=True,
         dtype=CPU_DTYPE,
+        sample_key=sample_key,
     )
     if full_path is not None:
         dprint(f"cpu validation info found for seed={seed} -- loading it")
@@ -367,13 +376,14 @@
 
 # warmup with any input so compiler produces criteria json
 # TODO: Swap this with __prepare_inputs once fix for shape_id is available
-# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer)
+# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
 prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
 # matching vllm warmup to pad to 2 on fp8, and no pad for fp16
 if is_fp8:
     prompt_list = prompt_list * 2
 input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
 extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
+extra_kwargs["attn_name"] = ATTN_NAME
 
 if (
     "granite-3.3-8b-instruct" in model_variant
@@ -494,7 +504,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
     for valid_prompt_shape in valid_prompt_shapes:
         if valid_prompt_shape == custom_shape:
             enforce_sizes = [valid_prompt_shape[1]]
-            input_ids, extra_kwargs = __prepare_inputs(
+            input_ids, extra_kwargs, sample_key = __prepare_inputs(
                 valid_prompt_shape[0],
                 valid_prompt_shape[1],
                 tokenizer,
@@ -506,6 +516,7 @@
                     custom_shape,
                     input_ids,
                     extra_kwargs,
+                    sample_key,
                 )
             ]
             break
@@ -566,7 +577,7 @@
                 )
             )
             try:
-                input_ids, extra_kwargs = __prepare_inputs(
+                input_ids, extra_kwargs, sample_key = __prepare_inputs(
                     valid_prompt_shape[0],
                     valid_prompt_shape[1],
                     tokenizer,
@@ -578,6 +589,7 @@
                         valid_prompt_shape,
                         input_ids,
                         extra_kwargs,
+                        sample_key,
                     )
                 )
                 used_keys.add(program_seq_key[0])
@@ -609,7 +621,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
 failed_cases = []
 
 # for each program and valid prompt (batch size, sequence length)
-for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts:
+for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
     extra_kwargs["attn_name"] = ATTN_NAME
     if (
         "granite-3.3-8b-instruct" in model_variant
@@ -634,6 +646,7 @@
         tokenizer,
         seed=0,
         attn_type=ATTN_NAME,
+        sample_key=sample_key,
     )
     # if the cpu validation info is not yet computed, compute it
     if cpu_validation_info is None:
@@ -657,6 +670,7 @@
                 0,
                 ATTN_NAME,
                 dtype=CPU_DTYPE,
+                sample_key=sample_key,
             )
         )
 
diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py
index 122c9664..d257c44c 100644
--- a/tests/models/test_decoders.py
+++ b/tests/models/test_decoders.py
@@ -364,7 +364,13 @@ def __filter_before_eos(metrics, filter_indexes):
 
 
 def __load_validation_info(
-    model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
+    model_path,
+    batch_size,
+    seq_length,
+    max_new_tokens,
+    tokenizer,
+    seed,
+    attn_type: str,
 ):
     # if path doesn't exist and paged isn't in the attention name, remove `attn_type` and recheck again, warn that we will no longer in the future have paths without 'attn_type'
     full_path = find_validation_info_path(
diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py
index ac3367ae..220f89e9 100644
--- a/tests/testing/test_validation.py
+++ b/tests/testing/test_validation.py
@@ -8,7 +8,14 @@
     get_validation_info_path,
     find_validation_info_path,
     __decrement_version,
+    get_default_validation_prefix,
 )
+import hashlib
+import os
+from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
+from aiu_fms_testing_utils.utils import sample_sharegpt_requests
+from transformers import AutoTokenizer
+
 from aiu_fms_testing_utils._version import version_tuple
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
@@ -73,12 +80,21 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook):
 
 
 def test_get_validation_info_path(tmp_path):
+    check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa"
+    hash_object = hashlib.sha256(check_pathname.encode("utf-8"))
+    hex_digest = hash_object.hexdigest()
+
     assert (
         get_validation_info_path(
             tmp_path, "ibm-granite/granite-3.3-8b-instruct", 4, 64, 128, 0, "sdpa"
         )
-        == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out"
+        == f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out"
     )
+
+    check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa"
+    hash_object = hashlib.sha256(check_pathname.encode("utf-8"))
+    hex_digest = hash_object.hexdigest()
+
     assert (
         get_validation_info_path(
             tmp_path,
@@ -90,7 +106,7 @@
             "sdpa",
             aftu_version=(1, 2, 3),
         )
-        == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.1.2.3.cpu_validation_info.0.out"
+        == f"{tmp_path}/{hex_digest}_1.2.3.cpu_validation_info.0.out"
     )
 
 
@@ -238,3 +254,81 @@
         + patch
         + 1
     )
+
+
+def test_format_kwargs_to_string():
+    kwargs = {
+        "enforce_sizes": [1, 32, 4, 8],
+        "batch_size": 1,
+        "model_id": "granite-3.3-8b",
+        "seq_len": 64,
+    }
+    kwargs_str = format_kwargs_to_string(**kwargs)
+    assert (
+        kwargs_str
+        == "batch-size-1_enforce-sizes-1,32,4,8_model-id-granite-3.3-8b_seq-len-64"
+    )
+
+
+DATASET_PATH = os.getenv(
+    "DATASET_PATH", "/mnt/home/models/ShareGPT_V3_unfiltered_cleaned_split.json"
+)
+TOKENIZER = os.getenv("TOKENIZER", "ibm-granite/granite-3.3-8b-Instruct")
+
+
+@pytest.mark.parametrize(
+    "model_variant,max_new_tokens,batch_size,seq_length,dtype,attn_type,device_type,seed,aftu_version",
+    [("granite-3.3-8b", 64, 2, 64, "fp16", "sdpa", "cpu", 0, (1, 2, 3))],
+)
+def test_get_default_validation_prefix(
+    model_variant,
+    max_new_tokens,
+    batch_size,
+    seq_length,
+    dtype,
+    attn_type,
+    device_type,
+    seed,
+    aftu_version,
+):
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
+
+    sample_key = None
+    # get_default_validation_prefix with sample_key set to None
+    check_prefix_sample_key_none = f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
+    hash_object = hashlib.sha256(check_prefix_sample_key_none.encode("utf-8"))
+    hex_digest = hash_object.hexdigest()
+    prefix_sample_key_none = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
+
+    assert prefix_sample_key_none == f"{hex_digest}_1.2.3.cpu_validation_info.0.out"
+
+    # get_default_validation_prefix with no kwargs using legacy case
+    legacy_prefix = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
+    assert prefix_sample_key_none == legacy_prefix
+
+    # retrieve a sample_key with return_key set to True
+    dataset_1, sample_key = sample_sharegpt_requests(
+        DATASET_PATH,
+        batch_size,
+        tokenizer,
+        32,
+        seq_length * 2,
+        seed=seed,
+        enforce_sizes=[],
+        return_key=True,
+    )
+
+    # Check sample key sorted by parameter name
+    assert sample_key.split("_") == sorted(sample_key.split("_"))
+
+    dataset_2 = sample_sharegpt_requests(
+        DATASET_PATH,
+        batch_size,
+        tokenizer,
+        32,
+        seq_length * 2,
+        seed=seed,
+        enforce_sizes=[],
+    )
+
+    assert dataset_1 == dataset_2