9 changes: 6 additions & 3 deletions tests/artifacts/predefined_data_configs/__init__.py
@@ -19,12 +19,15 @@
 
 ### Constants used for data
 PREDEFINED_DATA_CONFIGS = os.path.join(os.path.dirname(__file__))
-APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
+DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "apply_custom_template.yaml"
 )
-PRETOKENIZE_JSON_DATA_YAML = os.path.join(
+DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "pretokenized_json_data.yaml"
 )
-TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
+DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking.yaml"
 )
+DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
+)
@@ -11,4 +11,4 @@ datasets:
         batched: false
         fn_kwargs:
           dataset_text_field: "dataset_text_field"
-          dataset_template: "dataset_template"
+          template: "dataset_template"
@@ -0,0 +1,41 @@
+dataprocessor:
+    type: default
+    sampling_stopping_strategy: first_exhausted
+    seed: 66
+datasets:
+  - name: dataset_1
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_2
+    sampling: 0.4
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
+  - name: dataset_3
+    sampling: 0.3
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: input
+            output_field_name: output
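The code that consumes this config is not part of this diff. As a point of reference, the sampling fields above map naturally onto Hugging Face datasets.interleave_datasets; the sketch below is an assumption about that mapping, not code from the PR, and the dataset-loading lines are placeholders.

# Illustrative sketch only: how the sampling fields above would typically feed
# datasets.interleave_datasets. The real consumer lives outside this diff.
from datasets import interleave_datasets, load_dataset

# Placeholder loads standing in for dataset_1/2/3 after their data handlers ran.
ds_1 = load_dataset("json", data_files="dataset_1.jsonl", split="train")
ds_2 = load_dataset("json", data_files="dataset_2.jsonl", split="train")
ds_3 = load_dataset("json", data_files="dataset_3.jsonl", split="train")

mixed = interleave_datasets(
    [ds_1, ds_2, ds_3],
    probabilities=[0.3, 0.4, 0.3],        # the per-dataset `sampling` ratios
    seed=66,                              # `seed` from the dataprocessor block
    stopping_strategy="first_exhausted",  # `sampling_stopping_strategy`
)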
@@ -10,5 +10,5 @@ datasets:
           remove_columns: all
           batched: false
           fn_kwargs:
-            input_field: "INPUT"
-            output_field: "OUTPUT"
+            input_field_name: "INPUT"
+            output_field_name: "OUTPUT"
129 changes: 116 additions & 13 deletions tests/data/test_data_preprocessing_utils.py
@@ -26,9 +26,10 @@
 
 # First Party
 from tests.artifacts.predefined_data_configs import (
-    APPLY_CUSTOM_TEMPLATE_YAML,
-    PRETOKENIZE_JSON_DATA_YAML,
-    TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+    DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
+    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
     MODEL_NAME,
@@ -428,22 +429,22 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
 @pytest.mark.parametrize(
     "data_config_path, data_path",
     [
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
-        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
-        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
         ),
         (
-            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
         ),
     ],
@@ -709,3 +710,105 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
     with open(datafile, "r") as file:
         data = json.load(file)
     assert len(train_dataset) == len(data)
+
+
+@pytest.mark.parametrize(
+    "datafiles, sampling, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, None, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            [0.3, 0.5, 0.3],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling_error(
+    datafiles, sampling, datasetconfigname
+):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+                d["sampling"] = sampling[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    with pytest.raises(ValueError):
+        (_, _, _, _, _, _) = process_dataargs(
+            data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+        )
+
+
+@pytest.mark.parametrize(
+    "datafiles, datasetconfigname",
+    [
+        (
+            [
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+                TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            ],
+            DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+        ),
+    ],
+)
+def test_process_dataset_configs_with_sampling(datafiles, datasetconfigname):
+
+    data_args = configs.DataArguments()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    TRAIN_ARGS = configs.TrainingArguments(
+        packing=False,
+        max_seq_length=1024,
+        output_dir="tmp",  # Not needed but positional
+    )
+
+    with tempfile.NamedTemporaryFile(
+        "w", delete=False, suffix=".yaml"
+    ) as temp_yaml_file:
+        with open(datasetconfigname, "r") as f:
+            data = yaml.safe_load(f)
+            datasets = data["datasets"]
+            for i in range(len(datasets)):
+                d = datasets[i]
+                d["data_paths"][0] = datafiles[i]
+            yaml.dump(data, temp_yaml_file)
+        data_args.data_config_path = temp_yaml_file.name
+
+    (train_set, eval_set, _, _, _, _) = process_dataargs(
+        data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
+    )
+
+    assert isinstance(train_set, Dataset)
+    if eval_set:
+        assert isinstance(eval_set, Dataset)
+
+    assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
+    if eval_set:
+        assert set(["input_ids", "labels"]).issubset(set(eval_set.column_names))
39 changes: 27 additions & 12 deletions tuning/data/data_config.py
@@ -32,13 +32,16 @@ class DataHandlerConfig:
 class DataSetConfig:
     name: str
     data_paths: List[str]
-    sampling: Optional[Dict] = None
+    sampling: Optional[float] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
 
 
 @dataclass
 class DataPreProcessorConfig:
     type: Optional[str] = "default"
+    sampling_stopping_strategy: Optional[str] = "all_exhausted"
+    # Default seed is not None, to ensure reproducibility
+    sampling_seed: Optional[float] = 42
 
 
 @dataclass
@@ -84,17 +87,12 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
                 )
                 p = _p
             c.data_paths.append(p)
-    if "sampling" in kwargs:
-        sampling_kwargs = kwargs["sampling"]
-        assert isinstance(
-            dict, sampling_kwargs
-        ), "sampling arguments should be of the type dict"
-        if "ratio" in sampling_kwargs:
-            ratio = sampling_kwargs["ratio"]
-            assert isinstance(ratio, float) and (
-                0 <= ratio <= 1.0
-            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
-        c.sampling = sampling_kwargs
+    if "sampling" in kwargs and kwargs["sampling"] is not None:
+        ratio = kwargs["sampling"]
+        assert isinstance(ratio, float) and (
+            0 <= ratio <= 1.0
+        ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
+        c.sampling = ratio
     if "data_handlers" in kwargs:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
@@ -106,6 +104,23 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig:
     kwargs = dataprocessor_config
     c = DataPreProcessorConfig()
    assert isinstance(kwargs, dict), "dataprocessor in data_config needs to be a dict"
+    if "type" in kwargs:
+        assert isinstance(kwargs["type"], str), "dataprocessor type must be a string"
+        c.type = kwargs["type"]
+    if "sampling_stopping_strategy" in kwargs:
+        strategy = kwargs["sampling_stopping_strategy"]
+        assert isinstance(
+            strategy, str
+        ), "dataset sampling stopping strategy must be a string"
+        assert strategy in [
+            "first_exhausted",
+            "all_exhausted",
+        ], "allowed sampling stopping strategies are all_exhausted(default) or first_exhausted"
+        c.sampling_stopping_strategy = strategy
+    if "sampling_seed" in kwargs:
+        seed = kwargs["sampling_seed"]
+        assert isinstance(seed, int), "sampling seed should be int"
+        c.sampling_seed = seed
     return c


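For orientation, here is how the new DataPreProcessorConfig defaults compare with an explicit override, using only the dataclass and validator changed above; this usage sketch is not part of the diff.

# Usage sketch based solely on tuning/data/data_config.py as changed in this PR.
from tuning.data.data_config import (
    DataPreProcessorConfig,
    _validate_dataprocessor_config,
)

defaults = DataPreProcessorConfig()
# defaults.type == "default"
# defaults.sampling_stopping_strategy == "all_exhausted"  (sample until every dataset is fully seen)
# defaults.sampling_seed == 42                            (fixed by default for reproducibility)

overridden = _validate_dataprocessor_config(
    {"type": "default", "sampling_stopping_strategy": "first_exhausted", "sampling_seed": 66}
)
# overridden.sampling_stopping_strategy == "first_exhausted"
# overridden.sampling_seed == 66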