|
15 | 15 | """Unit Tests for SFT Trainer. |
16 | 16 | """ |
17 | 17 |
|
| 18 | +# pylint: disable=too-many-lines |
| 19 | + |
18 | 20 | # Standard |
19 | 21 | import copy |
20 | 22 | import json |
21 | 23 | import os |
22 | 24 | import tempfile |
| 25 | +from dataclasses import asdict |
23 | 26 |
|
24 | 27 | # Third Party |
25 | 28 | from datasets.exceptions import DatasetGenerationError |
|
46 | 49 | from tuning import sft_trainer |
47 | 50 | from tuning.config import configs, peft_config |
48 | 51 | from tuning.config.tracker_configs import FileLoggingTrackerConfig |
| 52 | +from tuning.data.data_config import ( |
| 53 | + DataConfig, |
| 54 | + DataHandlerConfig, |
| 55 | + DataPreProcessorConfig, |
| 56 | + DataSetConfig, |
| 57 | +) |
| 58 | +from tuning.data.data_handlers import apply_dataset_formatting |
49 | 59 |
|
50 | 60 | MODEL_ARGS = configs.ModelArguments( |
51 | 61 | model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" |
@@ -1124,3 +1134,88 @@ def test_pretokenized_dataset_wrong_format(): |
1124 | 1134 | # is essentially swallowing a KeyError here. |
1125 | 1135 | with pytest.raises(ValueError): |
1126 | 1136 | sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS) |
| 1137 | + |
| 1138 | + |
| 1139 | +@pytest.mark.parametrize( |
| 1140 | + "additional_handlers", |
| 1141 | + [ |
| 1142 | + "thisisnotokay", |
| 1143 | + [], |
| 1144 | + {lambda x: {"x": x}: "notokayeither"}, |
| 1145 | + {"thisisfine": "thisisnot"}, |
| 1146 | + ], |
| 1147 | +) |
| 1148 | +def test_run_with_bad_additional_data_handlers(additional_handlers): |
| 1149 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1150 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1151 | + train_args.output_dir = tempdir |
| 1152 | + |
| 1153 | + with pytest.raises( |
| 1154 | + ValueError, match="Handlers should be of type Dict, str to callable" |
| 1155 | + ): |
| 1156 | + sft_trainer.train( |
| 1157 | + MODEL_ARGS, |
| 1158 | + DATA_ARGS, |
| 1159 | + train_args, |
| 1160 | + PEFT_PT_ARGS, |
| 1161 | + additional_data_handlers=additional_handlers, |
| 1162 | + ) |
| 1163 | + |
| 1164 | + |
| 1165 | +def test_run_with_additional_data_handlers_as_none(): |
| 1166 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1167 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1168 | + train_args.output_dir = tempdir |
| 1169 | + |
| 1170 | + sft_trainer.train( |
| 1171 | + MODEL_ARGS, |
| 1172 | + DATA_ARGS, |
| 1173 | + train_args, |
| 1174 | + PEFT_PT_ARGS, |
| 1175 | + additional_data_handlers=None, |
| 1176 | + ) |
| 1177 | + |
| 1178 | + |
| 1179 | +def test_run_by_passing_additional_data_handlers(): |
| 1180 | + |
| 1181 | + # This is my test handler |
| 1182 | + TEST_HANDLER = "my_test_handler" |
| 1183 | + |
| 1184 | + def test_handler(element, tokenizer, **kwargs): |
| 1185 | + return apply_dataset_formatting(element, tokenizer, "custom_formatted_field") |
| 1186 | + |
| 1187 | + # This data config calls for data handler to be applied to dataset |
| 1188 | + preprocessor_config = DataPreProcessorConfig() |
| 1189 | + handler_config = DataHandlerConfig(name="my_test_handler", arguments=None) |
| 1190 | + dataaset_config = DataSetConfig( |
| 1191 | + name="test_dataset", |
| 1192 | + data_paths=TWITTER_COMPLAINTS_DATA_JSON, |
| 1193 | + data_handlers=[handler_config], |
| 1194 | + ) |
| 1195 | + data_config = DataConfig( |
| 1196 | + dataprocessor=preprocessor_config, datasets=[dataaset_config] |
| 1197 | + ) |
| 1198 | + |
| 1199 | + # dump the data config to a file, also test if json data config works |
| 1200 | + with tempfile.NamedTemporaryFile( |
| 1201 | + "w", delete=False, suffix=".json" |
| 1202 | + ) as temp_data_file: |
| 1203 | + data_config_raw = json.dumps(asdict(data_config)) |
| 1204 | + temp_data_file.write(data_config_raw) |
| 1205 | + data_config_path = temp_data_file.name |
| 1206 | + |
| 1207 | + # now launch sft trainer after registering data handler |
| 1208 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1209 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1210 | + train_args.output_dir = tempdir |
| 1211 | + data_args = copy.deepcopy(DATA_ARGS) |
| 1212 | + data_args.data_config_path = data_config_path |
| 1213 | + data_args.dataset_text_field = "custom_formatted_field" |
| 1214 | + |
| 1215 | + sft_trainer.train( |
| 1216 | + MODEL_ARGS, |
| 1217 | + DATA_ARGS, |
| 1218 | + train_args, |
| 1219 | + PEFT_PT_ARGS, |
| 1220 | + additional_data_handlers={TEST_HANDLER: test_handler}, |
| 1221 | + ) |
0 commit comments