22 changes: 22 additions & 0 deletions aiu_fms_testing_utils/testing/utils.py
@@ -0,0 +1,22 @@
from collections.abc import Iterable


def format_kwargs_to_string(**kwargs):
"""
Turns kwargs into a str with variable names using `-`, variables separated by `_` and iterable separated by `,`
"""
    formatted_pairs = []
    for key, value in sorted(kwargs.items()):
        formatted_value = None
        if isinstance(value, str):
            formatted_value = value
        elif isinstance(value, Iterable):
            formatted_value = ",".join(map(str, value))
        elif value:
            formatted_value = str(value)
        # only append if formatted_value exists
        if formatted_value:
            # keep the previous convention of variable names with `-` instead of `_`
            formatted_pairs.append(f"{key.replace('_', '-')}-{formatted_value}")

return "_".join(formatted_pairs)
24 changes: 21 additions & 3 deletions aiu_fms_testing_utils/testing/validation.py
@@ -5,6 +5,9 @@
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from aiu_fms_testing_utils._version import version_tuple
import os
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

import hashlib


class LogitsExtractorHook(
@@ -132,6 +135,7 @@ def get_default_validation_prefix(
    dtype: str,
    attn_type: str,
    aftu_version: str,
    **kwargs,
):
    """
    Args:
@@ -144,9 +148,17 @@
        aftu_version (str): introduced in v0.3.0 to track changes in the log

    Returns:
        str: A prefix that will be prepended to the file name
        str: A hashed prefix that will be prepended to the file name
    """
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
kwargs_str = format_kwargs_to_string(**kwargs)

if kwargs_str == "":
filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
else:
filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}"
hash_object = hashlib.sha256(filename.encode("utf-8"))
hex_digest = hash_object.hexdigest()
return f"{hex_digest}_{aftu_version}"


def load_validation_information(
@@ -416,11 +428,14 @@ def get_validation_info_path(
    aftu_version: Optional[Tuple[int, int, int]] = None,
    device_type: str = "cpu",
    dtype: str = "fp16",
    **kwargs,
):
    if aftu_version is None:
        aftu_version = version_tuple

    validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
    sample_key = kwargs.get("sample_key", None)

    validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
    full_path = os.path.join(validation_info_dir, validation_file_name)
    return full_path
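
Putting the pieces together, the on-disk file name is the hashed prefix plus device and seed metadata; schematically (digest elided, values hypothetical):

    # {sha256-of-descriptive-name}_{aftu_version}.{device_type}_validation_info.{seed}.out
    # e.g. <64-hex-digest>_0.3.0.cpu_validation_info.0.out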

@@ -452,10 +467,12 @@ def find_validation_info_path(
    version_allow_decrement: bool = False,
    device_type: str = "cpu",
    dtype: str = "fp16",
    **kwargs,
):
    """
    Find the validation info path if it exists, otherwise return None
    """
    sample_key = kwargs.get("sample_key", None)

    if aftu_version is None:
        loc_version_tuple = version_tuple[:3]
@@ -476,6 +493,7 @@
            loc_version_tuple,
            device_type,
            dtype,
            sample_key=sample_key,
        )
        # if the path is found, we are done searching and can return
        if os.path.exists(full_path):
62 changes: 59 additions & 3 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -11,6 +11,7 @@

from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

from fms.utils.generation import pad_input_ids
import torch
@@ -482,6 +483,7 @@ def sample_rag_factoid_requests(
    enforce_sizes: List[int] = [],
    truncation: bool = False,
    pad_multiple: int = 64,
    return_key: bool = False,
) -> List[Tuple[str, int]]:
    if not os.path.exists(dataset_path):
        print("error dataset does not exist")
@@ -492,7 +494,7 @@
        for line in f:
            dataset.append(line)

    return __sample_requests(
    sample_request = __sample_requests(
        dataset,
        num_requests,
        tokenizer,
@@ -506,6 +508,24 @@
        _cached_dataset_key=dataset_path,
    )

    if return_key:
        sample_key: str = format_kwargs_to_string(
            dataset="rag_factoid",
            num_requests=num_requests,
            tokenizer=tokenizer.name_or_path.replace("/", "--"),
            prompt_length_min=prompt_length_min,
            prompt_length_max=prompt_length_max,
            seed=seed,
            enforce_heterogeneous=enforce_heterogeneous,
            enforce_sizes=enforce_sizes,
            truncate=truncation,
            pad_multiple=pad_multiple,
        )

        return sample_request, sample_key
    else:
        return sample_request
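
A minimal usage sketch of the new flag (dataset path and values hypothetical); the same pattern applies to the sharegpt and squad samplers below:

    prompts, sample_key = sample_rag_factoid_requests(
        "rag_factoid.txt",
        num_requests=4,
        tokenizer=tokenizer,
        prompt_length_min=32,
        prompt_length_max=64,
        seed=0,
        return_key=True,
    )
    # sample_key can then be passed through to get/find_validation_info_path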


def sample_sharegpt_requests(
    dataset_path: str,
@@ -518,6 +538,7 @@
    enforce_sizes: List[int] | None = None,
    truncation: bool = False,
    pad_multiple: int = 64,
    return_key: bool = False,
) -> List[Tuple[str, int]]:
    if not os.path.exists(dataset_path):
        print("downloading share-gpt dataset as it does not exist")
@@ -543,7 +564,7 @@
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]

    return __sample_requests(
    sample_request = __sample_requests(
        dataset,
        num_requests,
        tokenizer,
@@ -557,6 +578,23 @@
        _cached_dataset_key=dataset_path,
    )

    if return_key:
        sample_key: str = format_kwargs_to_string(
            dataset="sharegpt",
            num_requests=num_requests,
            tokenizer=tokenizer.name_or_path.replace("/", "--"),
            prompt_length_min=prompt_length_min,
            prompt_length_max=prompt_length_max,
            seed=seed,
            enforce_heterogeneous=enforce_heterogeneous,
            enforce_sizes=enforce_sizes,
            truncate=truncation,
            pad_multiple=pad_multiple,
        )
        return sample_request, sample_key
    else:
        return sample_request


def sample_squad_v2_qa_requests(
    dataset_path: str,
@@ -569,6 +607,7 @@
    enforce_sizes: List[int] | None = None,
    truncation: bool = False,
    pad_multiple: int = 64,
    return_key: bool = False,
) -> List[Tuple[str, int]]:
    from datasets import load_dataset

@@ -582,7 +621,7 @@

    ds = [f"{data['context']}\n{data['question']}" for data in ds]

    return __sample_requests(
    sample_request = __sample_requests(
        ds,
        num_requests,
        tokenizer,
@@ -595,6 +634,23 @@
        pad_multiple,
    )

    if return_key:
        sample_key: str = format_kwargs_to_string(
            dataset="squad_v2",
            num_requests=num_requests,
            tokenizer=tokenizer.name_or_path.replace("/", "--"),
            prompt_length_min=prompt_length_min,
            prompt_length_max=prompt_length_max,
            seed=seed,
            enforce_heterogeneous=enforce_heterogeneous,
            enforce_sizes=enforce_sizes,
            truncate=truncation,
            pad_multiple=pad_multiple,
        )
        return sample_request, sample_key
    else:
        return sample_request


def prepare_inputs(
    batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"
26 changes: 20 additions & 6 deletions scripts/drive_paged_programs.py
@@ -40,6 +40,7 @@
    get_programs_prompts,
    KVCACHE_NUM_BLOCKS_HINT,
)
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

parser = argparse.ArgumentParser(
    description="Script which will drive paged programs for debugging"
@@ -195,6 +196,10 @@
custom_shape = (len(result), max([_[1] for _ in result]))

def __custom_line_sampler(*args, **kwargs):
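    # stand-in for the dataset samplers: mirrors their return_key contract but always returns the precomputed custom-shape result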
    return_key = kwargs.get("return_key", False)
    sample_key = format_kwargs_to_string(**kwargs)
    if return_key:
        return result, sample_key
    return result

sampler = __custom_line_sampler
@@ -245,7 +250,7 @@

def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
    start = time.time()
    prompts_and_sizes = sampler(
    prompts_and_sizes, sample_key = sampler(
        DATASET_PATH,
        batch_size,
        tokenizer,
@@ -254,6 +259,7 @@
        seed,
        enforce_sizes=enforce_sizes,
        truncation=allow_truncation,
        return_key=True,
    )
    end = time.time()
    if local_rank == 0:
@@ -274,7 +280,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):

    input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
    extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
    return input_ids, extra_kwargs
    return input_ids, extra_kwargs, sample_key


def __maybe_prepare_fp8_weights(model_in, is_fp8):
@@ -296,7 +302,9 @@ def __load_validation_info(
    tokenizer,
    seed,
    attn_type: str,
    **kwargs,
):
    sample_key = kwargs.get("sample_key", None)
    full_path = find_validation_info_path(
        args.validation_info_outputs_dir,
        model_variant,
@@ -307,6 +315,7 @@
        attn_type,
        version_allow_decrement=True,
        dtype=CPU_DTYPE,
        sample_key=sample_key,
    )
    if full_path is not None:
        dprint(f"cpu validation info found for seed={seed} -- loading it")
@@ -367,13 +376,14 @@ def __load_validation_info(

# warmup with any input so compiler produces criteria json
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer)
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
if is_fp8:
    prompt_list = prompt_list * 2
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)

extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
@@ -494,7 +504,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
for valid_prompt_shape in valid_prompt_shapes:
    if valid_prompt_shape == custom_shape:
        enforce_sizes = [valid_prompt_shape[1]]
        input_ids, extra_kwargs = __prepare_inputs(
        input_ids, extra_kwargs, sample_key = __prepare_inputs(
            valid_prompt_shape[0],
            valid_prompt_shape[1],
            tokenizer,
@@ -506,6 +516,7 @@
                custom_shape,
                input_ids,
                extra_kwargs,
                sample_key,
            )
        ]
        break
@@ -566,7 +577,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
            )
        )
        try:
            input_ids, extra_kwargs = __prepare_inputs(
            input_ids, extra_kwargs, sample_key = __prepare_inputs(
                valid_prompt_shape[0],
                valid_prompt_shape[1],
                tokenizer,
@@ -578,6 +589,7 @@
                    valid_prompt_shape,
                    input_ids,
                    extra_kwargs,
                    sample_key,
                )
            )
            used_keys.add(program_seq_key[0])
@@ -609,7 +621,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):

failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts:
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
    extra_kwargs["attn_name"] = ATTN_NAME
    if (
        "granite-3.3-8b-instruct" in model_variant
@@ -634,6 +646,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
        tokenizer,
        seed=0,
        attn_type=ATTN_NAME,
        sample_key=sample_key,
    )
    # if the cpu validation info is not yet computed, compute it
    if cpu_validation_info is None:
@@ -657,6 +670,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
                0,
                ATTN_NAME,
                dtype=CPU_DTYPE,
                sample_key=sample_key,
            )
        )

8 changes: 7 additions & 1 deletion tests/models/test_decoders.py
@@ -364,7 +364,13 @@ def __filter_before_eos(metrics, filter_indexes):


def __load_validation_info(
    model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
    model_path,
    batch_size,
    seq_length,
    max_new_tokens,
    tokenizer,
    seed,
    attn_type: str,
):
    # if the path doesn't exist and "paged" isn't in the attention name, retry without `attn_type`; warn that paths without 'attn_type' will not be supported in the future
    full_path = find_validation_info_path(