Skip to content

Commit 64bf80a

Browse files
committed
add unit tests for additional data handlers
Signed-off-by: Dushyant Behl <[email protected]>
1 parent cad3a2d commit 64bf80a

File tree

2 files changed

+100
-5
lines changed

2 files changed

+100
-5
lines changed

tests/test_sft_trainer.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
"""Unit Tests for SFT Trainer.
1616
"""
1717

18+
# pylint: disable=too-many-lines
19+
1820
# Standard
1921
import copy
2022
import json
2123
import os
2224
import tempfile
25+
from dataclasses import asdict
2326

2427
# Third Party
2528
from datasets.exceptions import DatasetGenerationError
@@ -46,6 +49,13 @@
4649
from tuning import sft_trainer
4750
from tuning.config import configs, peft_config
4851
from tuning.config.tracker_configs import FileLoggingTrackerConfig
52+
from tuning.data.data_config import (
53+
DataConfig,
54+
DataHandlerConfig,
55+
DataPreProcessorConfig,
56+
DataSetConfig,
57+
)
58+
from tuning.data.data_handlers import apply_dataset_formatting
4959

5060
MODEL_ARGS = configs.ModelArguments(
5161
model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -1124,3 +1134,88 @@ def test_pretokenized_dataset_wrong_format():
11241134
# is essentially swallowing a KeyError here.
11251135
with pytest.raises(ValueError):
11261136
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
1137+
1138+
1139+
@pytest.mark.parametrize(
1140+
"additional_handlers",
1141+
[
1142+
"thisisnotokay",
1143+
[],
1144+
{lambda x: {"x": x}: "notokayeither"},
1145+
{"thisisfine": "thisisnot"},
1146+
],
1147+
)
1148+
def test_run_with_bad_additional_data_handlers(additional_handlers):
1149+
with tempfile.TemporaryDirectory() as tempdir:
1150+
train_args = copy.deepcopy(TRAIN_ARGS)
1151+
train_args.output_dir = tempdir
1152+
1153+
with pytest.raises(
1154+
ValueError, match="Handlers should be of type Dict, str to callable"
1155+
):
1156+
sft_trainer.train(
1157+
MODEL_ARGS,
1158+
DATA_ARGS,
1159+
train_args,
1160+
PEFT_PT_ARGS,
1161+
additional_data_handlers=additional_handlers,
1162+
)
1163+
1164+
1165+
def test_run_with_additional_data_handlers_as_none():
1166+
with tempfile.TemporaryDirectory() as tempdir:
1167+
train_args = copy.deepcopy(TRAIN_ARGS)
1168+
train_args.output_dir = tempdir
1169+
1170+
sft_trainer.train(
1171+
MODEL_ARGS,
1172+
DATA_ARGS,
1173+
train_args,
1174+
PEFT_PT_ARGS,
1175+
additional_data_handlers=None,
1176+
)
1177+
1178+
1179+
def test_run_by_passing_additional_data_handlers():
1180+
1181+
# This is my test handler
1182+
TEST_HANDLER = "my_test_handler"
1183+
1184+
def test_handler(element, tokenizer, **kwargs):
1185+
return apply_dataset_formatting(element, tokenizer, "custom_formatted_field")
1186+
1187+
# This data config calls for data handler to be applied to dataset
1188+
preprocessor_config = DataPreProcessorConfig()
1189+
handler_config = DataHandlerConfig(name="my_test_handler", arguments=None)
1190+
dataaset_config = DataSetConfig(
1191+
name="test_dataset",
1192+
data_paths=TWITTER_COMPLAINTS_DATA_JSON,
1193+
data_handlers=[handler_config],
1194+
)
1195+
data_config = DataConfig(
1196+
dataprocessor=preprocessor_config, datasets=[dataaset_config]
1197+
)
1198+
1199+
# dump the data config to a file, also test if json data config works
1200+
with tempfile.NamedTemporaryFile(
1201+
"w", delete=False, suffix=".json"
1202+
) as temp_data_file:
1203+
data_config_raw = json.dumps(asdict(data_config))
1204+
temp_data_file.write(data_config_raw)
1205+
data_config_path = temp_data_file.name
1206+
1207+
# now launch sft trainer after registering data handler
1208+
with tempfile.TemporaryDirectory() as tempdir:
1209+
train_args = copy.deepcopy(TRAIN_ARGS)
1210+
train_args.output_dir = tempdir
1211+
data_args = copy.deepcopy(DATA_ARGS)
1212+
data_args.data_config_path = data_config_path
1213+
data_args.dataset_text_field = "custom_formatted_field"
1214+
1215+
sft_trainer.train(
1216+
MODEL_ARGS,
1217+
DATA_ARGS,
1218+
train_args,
1219+
PEFT_PT_ARGS,
1220+
additional_data_handlers={TEST_HANDLER: test_handler},
1221+
)

tuning/data/data_processors.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ def __init__(
4747
self.registered_handlers = {}
4848

4949
def register_data_handler(self, name: str, func: Callable):
50-
assert isinstance(name, str), "Handler name should be of str type"
51-
assert callable(func), "Handler should be a callable routine"
50+
if not isinstance(name, str) or not callable(func):
51+
raise ValueError("Handlers should be of type Dict, str to callable")
52+
logging.info("Registering handler %s passed by the user", name)
5253
self.registered_handlers[name] = func
5354

5455
def register_data_handlers(self, handlers: Dict[str, Callable]):
5556
if handlers is None:
5657
return
57-
assert isinstance(
58-
handlers, Dict
59-
), "Handlers should be of type Dict[str:Callable]"
58+
if not isinstance(handlers, Dict):
59+
raise ValueError("Handlers should be of type Dict, str to callable")
6060
for k, v in handlers.items():
6161
self.register_data_handler(name=k, func=v)
6262

0 commit comments

Comments
 (0)