
Commit ac148eb

dushyantbehl authored and willmj committed
Update tuning/data/data_processors.py

Co-authored-by: Will Johnson <[email protected]>
Signed-off-by: Dushyant Behl <[email protected]>

1 parent 251902b · commit ac148eb

3 files changed: +20 additions, -27 deletions


tuning/data/data_handlers.py

Lines changed: 9 additions & 11 deletions
@@ -15,7 +15,7 @@
 # Definition of some predefined data preprocessing functions that we need.
 
 # Standard
-from typing import Dict, List
+from typing import Dict
 
 # Third Party
 from transformers import AutoTokenizer
@@ -56,12 +56,12 @@ def tokenize_and_apply_input_masking(
 def apply_dataset_formatting(
     element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs
 ):
-    if isinstance(element[dataset_text_field], list):  # batched = True
-        return {
-            f"{dataset_text_field}": [
-                text + tokenizer.eos_token for text in element[f"{dataset_text_field}"]
-            ]
-        }
+    if isinstance(element[dataset_text_field], list):  # batched = True
+        return {
+            f"{dataset_text_field}": [
+                text + tokenizer.eos_token for text in element[f"{dataset_text_field}"]
+            ]
+        }
     return {
         f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
     }
@@ -77,13 +77,11 @@ def apply_custom_data_formatting_template(
     template += tokenizer.eos_token
 
     # TODO: Eventually move the code here.
-    return custom_data_formatter(
-        element=element, formatted_dataset_field=dataset_text_field, template=template
-    )
+    return custom_data_formatter(element, template, dataset_text_field)
 
 
 AVAILABLE_DATA_HANDLERS = {
     "tokenize_and_apply_instruction_masking": tokenize_and_apply_input_masking,
     "apply_dataset_formatting": apply_dataset_formatting,
-    "apply_custom_data_formatting_template": apply_dataset_formatting,
+    "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
 }
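
The registry fix in the last hunk is the substantive change here: handlers are dispatched by name from AVAILABLE_DATA_HANDLERS, so the old entry silently ran apply_dataset_formatting whenever a config asked for the template handler. A minimal sketch of that name-based dispatch (hypothetical driver code, not taken from the repo):

# Minimal sketch of name-based handler dispatch (hypothetical driver
# code; only apply_dataset_formatting mirrors the handler in the diff).
from transformers import AutoTokenizer

def apply_dataset_formatting(element, tokenizer, dataset_text_field, **kwargs):
    # Append the EOS token to the configured text field of one element.
    return {dataset_text_field: element[dataset_text_field] + tokenizer.eos_token}

AVAILABLE_DATA_HANDLERS = {
    "apply_dataset_formatting": apply_dataset_formatting,
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
handler = AVAILABLE_DATA_HANDLERS["apply_dataset_formatting"]  # lookup by name
print(handler({"output": "hello"}, tokenizer, "output"))
# -> {'output': 'hello<|endoftext|>'}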

tuning/data/data_processors.py

Lines changed: 7 additions & 8 deletions
@@ -19,7 +19,7 @@
 import os
 
 # Third Party
-from datasets import Dataset, IterableDataset, DatasetDict
+from datasets import Dataset, DatasetDict, IterableDataset
 from datasets.exceptions import DatasetNotFoundError
 from transformers import AutoTokenizer
 import datasets
@@ -126,7 +126,7 @@ def _process_dataset_configs(
         if d.sampling:
             logging.warning("Sampling multiple datasets is not supported yet")
 
-        if d.data_handlers: # Execute the datahandlers
+        if d.data_handlers:  # Execute the datahandlers
             for data_handler in d.data_handlers:
                 handler_name: str = data_handler.name
                 handler: callable = self.registered_handlers[handler_name]
@@ -157,8 +157,7 @@ def _process_dataset_configs(
 
                 kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)
 
-                # logging.info
-                # assert ("Applying Handler : {data_handler} Args : {kwargs}") == ""
+                logging.info("Applying Handler : {data_handler} Args : {kwargs}")
 
                 raw_datasets = raw_datasets.map(handler, **kwargs)
 
@@ -205,12 +204,12 @@ def get_dataprocessor(
 ) -> DataPreProcessor:
     loader = dataloaderconfig.type
     if loader == "default":
-        procesor = HFBasedDataPreProcessor(
+        processor = HFBasedDataPreProcessor(
             dataloaderconfig=dataloaderconfig,
             tokenizer=tokenizer,
             accelerator=accelerator,
         )
     else:
-        procesor = None
-    autoregister_available_handlers(procesor)
-    return procesor
+        processor = None
+    autoregister_available_handlers(processor)
+    return processor
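
Worth flagging in the middle hunk: the restored logging.info call has no f-string prefix and no arguments, so {data_handler} and {kwargs} are logged as literal braces. A lazy %-style call would interpolate the values (a suggested variant, not part of this commit):

import logging

logging.basicConfig(level=logging.INFO)
data_handler, kwargs = "apply_dataset_formatting", {"batched": True}

# As committed: no f-prefix, so the braces are logged verbatim:
# "Applying Handler : {data_handler} Args : {kwargs}"
logging.info("Applying Handler : {data_handler} Args : {kwargs}")

# Lazy %-style formatting lets the logging framework interpolate the
# actual handler name and arguments only if the record is emitted.
logging.info("Applying Handler : %s Args : %s", data_handler, kwargs)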

tuning/data/setup_dataprocessor.py

Lines changed: 4 additions & 8 deletions
@@ -17,19 +17,17 @@
 import logging
 
 # Third Party
-from datasets import Dataset, DatasetDict
 from transformers import AutoTokenizer
 
 # Local
 from tuning.config.configs import DataArguments
 from tuning.data.data_config import (
-    DataConfig,
     DataHandlerConfig,
     DataLoaderConfig,
     DataSetConfig,
     load_and_validate_data_config,
 )
-from tuning.data.data_processors import DataPreProcessor, get_dataprocessor
+from tuning.data.data_processors import get_dataprocessor
 from tuning.utils.preprocessing_utils import (
     JSON_INPUT_KEY,
     JSON_OUTPUT_KEY,
@@ -45,8 +43,6 @@ def process_dataargs(
     if data_args.validation_data_path:
         validation_dataset = True
 
-    dataset_text_field = data_args.dataset_text_field
-
     # Create a data processor with default loader config
     default_loader_config = DataLoaderConfig()
     data_processor = get_dataprocessor(
@@ -72,6 +68,8 @@ def process_dataargs(
     fn_kwargs = {}
     handlers = None
 
+    dataset_text_field = data_args.dataset_text_field
+
     # Use case specific handlers
     if is_train_data_pretokenized:
         # dataset_text_field is irrelevant to pretokenized datasets
@@ -95,7 +93,6 @@ def process_dataargs(
         )
         handlers = [handler]
     else:
-
         # TODO: These should be called DEFAULT in the name as they are hardcoded.
         fn_kwargs["input_field_name"] = JSON_INPUT_KEY
         fn_kwargs["output_field_name"] = JSON_OUTPUT_KEY
@@ -132,11 +129,10 @@ def process_dataargs(
     if validation_dataset:
         eval_dataset = data_processor.process_dataset_configs([eval_dataset_config])
         logging.info("Validation dataset length is %s", len(eval_dataset))
-    # dataset_text_field is irrelevant to pretokenized datasets
+
     return train_dataset, eval_dataset, dataset_text_field
 
 
-# TODO: This is very basic the handling of validation will be done by adding splitter.
 # For now assume 2 differnet arguments for training and validation dataset config files.
 # This is very limited but is done to keep first implementation minimal
 def process_dataconfig_file(dataconfigfile: str, tokenizer: AutoTokenizer):
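
The process_dataargs change simply defers reading data_args.dataset_text_field until the use-case-specific handlers are configured. A condensed sketch of the resulting flow (the DataArguments stand-in and driver below are assumptions for illustration, not the repo's full code):

from dataclasses import dataclass

# Stand-in for tuning.config.configs.DataArguments (assumed fields).
@dataclass
class DataArguments:
    dataset_text_field: str = "output"

def process_dataargs(data_args: DataArguments, is_train_data_pretokenized: bool):
    fn_kwargs = {}

    # Read moved here from the top of the function: the field is only
    # needed once use-case-specific handler kwargs are being built.
    dataset_text_field = data_args.dataset_text_field

    if is_train_data_pretokenized:
        pass  # dataset_text_field is irrelevant to pretokenized datasets
    else:
        fn_kwargs["dataset_text_field"] = dataset_text_field

    return fn_kwargs, dataset_text_field

print(process_dataargs(DataArguments(), is_train_data_pretokenized=False))
# -> ({'dataset_text_field': 'output'}, 'output')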
