Commit f1fd130

feat: Add support for renaming and retaining columns in data preprocessor (#466)
* Add functionality to rename and retain dataset columns in data preprocessor.
* fix fmt
* add unit tests
* Update advanced-data-preprocessing.md

Signed-off-by: Dushyant Behl <[email protected]>
1 parent a89a4a3 commit f1fd130

8 files changed: +129, -2 lines changed

docs/advanced-data-preprocessing.md

Lines changed: 10 additions & 0 deletions
@@ -60,6 +60,10 @@ definitions:
         type: float
       builder:
         type: string
+      rename_columns:
+        type: object
+      retain_columns:
+        type: object
       data_paths:
         type: array
         items:
@@ -118,6 +122,8 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `name` (optional, str): A unique identifier for the dataset.
 - `data_paths` (optional, list): A `list` of file paths or directories containing the dataset.
 - `builder` (optional, str): Specifies a [Hugging Face dataset builder](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/loading_methods#datasets.load_dataset.path), if applicable.
+- `rename_columns` (optional, dict[str, str]): Specifies a dictionary of columns to rename at dataset load time, e.g. `{"old_name": "new_name"}`. *Applied before `retain_columns` if both are specified.*
+- `retain_columns` (optional, list[str]): Specifies a list of columns to retain at dataset load time, e.g. `["input_ids", "labels"]`; every other column is dropped. *Applied strictly after `rename_columns` if both are specified.*
 - `sampling` (optional, float): The sampling ratio (0.0 to 1.0) with which to sample a dataset in case of interleaving.
 - `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.

@@ -149,6 +155,10 @@ Not Supported:
 Currently there's no support for sampling under multiple data paths which are defined inside a dataset definition.
 All dataset paths that will be specified inside one dataset will be [concatenated](https://huggingface.co/docs/datasets/v3.2.0/en/process#concatenate) after loading them, while across datasets users can specify [mixing via sampling datasets](#data-mixing)
 
+Additionally, while loading the dataset, users can specify which columns to rename via `rename_columns` and which to retain via `retain_columns`, as described above.
+The order of application of these operations is *strictly rename followed by retain*, so an old column name that has been renamed is no longer available to retain; users should keep this in mind when combining the two operations. The code throws a `ValueError` if a column requested to be renamed via the rename argument is also specified in the retain argument.
+
 
 ### How can users specify data handlers.
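To make the ordering in the hunk above concrete, here is a minimal sketch of the rename-then-retain semantics in terms of the underlying Hugging Face `datasets` API (the toy data and column names below are illustrative, not part of this commit):

```python
from datasets import Dataset

# Toy dataset with hypothetical columns.
ds = Dataset.from_dict(
    {"ID": [0, 1], "input": ["a", "b"], "output": ["x", "y"], "Label": ["l0", "l1"]}
)

# Rename is applied first ...
ds = ds.rename_columns({"input": "instruction", "output": "response"})

# ... then retain, which must use the *new* names; retaining "input"
# here would fail because that column no longer exists after the rename.
ds = ds.select_columns(["instruction", "response"])

print(ds.column_names)  # ['instruction', 'response']
```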

docs/ept.md

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ datasets:
     data_paths:
       - "<path-to-the-jsonl-dataset>"
     data_handlers:
-      - name: apply_custom_data_formatting
+      - name: add_tokenizer_eos_token
         arguments:
           remove_columns: all
           batched: false
@@ -109,4 +109,4 @@ The code again would add `EOS_TOKEN` to the non tokenized data before using it a
 
 ### Additional Information
 This feature is supported post [v2.3.1](https://github.com/foundation-model-stack/fms-hf-tuning/releases/tag/v2.3.1) of this library.
-Post Last Updated On: 10-02-2025
+Post Last Updated On: 12-02-2025

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -37,3 +37,6 @@
 DATA_CONFIG_DUPLICATE_COLUMNS = os.path.join(
     PREDEFINED_DATA_CONFIGS, "duplicate_columns.yaml"
 )
+DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
+)
tests/artifacts/predefined_data_configs/rename_retain_columns.yaml

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+dataprocessor:
+  type: default
+datasets:
+  - name: text_dataset_input_output_masking
+    rename_columns:
+      "input": "instruction"
+      "output": "response"
+    retain_columns:
+      - "instruction"
+      - "response"
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize_and_apply_input_masking
+        arguments:
+          remove_columns: all
+          batched: false
+          fn_kwargs:
+            input_field_name: instruction
+            output_field_name: response

tests/data/test_data_preprocessing.py

Lines changed: 55 additions & 0 deletions
@@ -33,6 +33,7 @@
     DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
+    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
@@ -1365,3 +1366,57 @@ def test_process_dataset_configs_with_sampling_error(
         (_, _, _, _, _, _) = process_dataargs(
             data_args=data_args, tokenizer=tokenizer, train_args=TRAIN_ARGS
         )
+
+
+@pytest.mark.parametrize(
+    "datafile, rename, retain, final, datasetconfigname",
+    [
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            {"input": "instruction", "output": "response"},
+            None,
+            ["ID", "Label", "instruction", "response"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            None,
+            ["ID", "input", "output"],
+            ["ID", "input", "output"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
+            {"input": "instruction", "output": "response"},
+            ["Label", "instruction", "response"],
+            ["Label", "instruction", "response"],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
+    ],
+)
+def test_rename_and_retain_dataset_columns(
+    datafile, rename, retain, final, datasetconfigname
+):
+    """Test process_dataset_configs for expected output."""
+    dataprocessor_config = DataPreProcessorConfig()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    processor = DataPreProcessor(
+        processor_config=dataprocessor_config,
+        tokenizer=tokenizer,
+    )
+    datasetconfig = [
+        DataSetConfig(
+            name=datasetconfigname,
+            data_paths=[datafile],
+            rename_columns=rename,
+            retain_columns=retain,
+        )
+    ]
+    train_dataset = processor.process_dataset_configs(dataset_configs=datasetconfig)
+
+    assert isinstance(train_dataset, Dataset)
+    assert set(train_dataset.column_names) == set(final)
+
+    with open(datafile, "r") as file:
+        data = json.load(file)
+    assert len(train_dataset) == len(data)

tests/test_sft_trainer.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,7 @@
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
 )
 from tests.artifacts.testdata import (
@@ -837,6 +838,10 @@ def test_run_causallm_ft_pretokenized(dataset_path):
         ],
         DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
         ),
+        (
+            [TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON],
+            DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+        ),
     ],
 )
 def test_run_causallm_ft_and_inference_with_multiple_dataset(

tuning/data/data_config.py

Lines changed: 14 additions & 0 deletions
@@ -36,6 +36,8 @@ class DataSetConfig:
     data_paths: List[str]
     builder: Optional[str] = None  # Referring to Hugging Face dataset builder
     sampling: Optional[float] = None
+    rename_columns: Optional[Dict] = None
+    retain_columns: Optional[List] = None
     data_handlers: Optional[List[DataHandlerConfig]] = None
@@ -100,6 +102,18 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
             0 <= ratio <= 1.0
         ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
         c.sampling = ratio
+    if "rename_columns" in kwargs and kwargs["rename_columns"] is not None:
+        rename = kwargs["rename_columns"]
+        assert isinstance(
+            rename, dict
+        ), "rename_columns should be a dict with current_name:new_name"
+        c.rename_columns = rename
+    if "retain_columns" in kwargs and kwargs["retain_columns"] is not None:
+        retain = kwargs["retain_columns"]
+        assert isinstance(
+            retain, list
+        ), "retain_columns should be a list[str] with names of columns to retain"
+        c.retain_columns = retain
     if "data_handlers" in kwargs:
         c.data_handlers = []
         for handler in kwargs["data_handlers"]:
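For reference, a minimal sketch of constructing this config programmatically with the two new fields (the dataset name and data path below are hypothetical):

```python
from tuning.data.data_config import DataSetConfig

# Hypothetical dataset config exercising the new fields; validation in
# _validate_dataset_config checks that rename_columns is a dict and
# retain_columns is a list.
cfg = DataSetConfig(
    name="my_dataset",  # hypothetical name
    data_paths=["/path/to/dataset.json"],  # hypothetical path
    rename_columns={"input": "instruction", "output": "response"},
    retain_columns=["instruction", "response"],
)
```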

tuning/data/data_processors.py

Lines changed: 20 additions & 0 deletions
@@ -243,6 +243,26 @@ def _process_dataset_configs(
 
         logger.info("Loaded raw dataset : %s", str(raw_dataset))
 
+        # Check for conflicting options before proceeding.
+        if d.rename_columns and d.retain_columns:
+            common = set(d.rename_columns.keys()) & set(d.retain_columns)
+            if common:
+                raise ValueError(
+                    f"You are trying to retain columns {common}"
+                    " which will be renamed via the rename operation."
+                )
+
+        if d.rename_columns:
+            logger.info("Renaming %s columns", str(d.rename_columns))
+            raw_dataset = raw_dataset.rename_columns(
+                column_mapping=d.rename_columns
+            )
+            logger.info("Done")
+        if d.retain_columns:
+            logger.info("Retaining %s columns", str(d.retain_columns))
+            raw_dataset = raw_dataset.select_columns(column_names=d.retain_columns)
+            logger.info("Done")
+
         raw_datasets = DatasetDict()
 
         # Assume all is train split
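As a usage note, the conflict check above means a rename source may not also appear in `retain_columns`; a small sketch of a failing and a corrected configuration (column names are hypothetical):

```python
# Raises ValueError during preprocessing: "input" is a rename source,
# so it will no longer exist by the time retain is applied.
bad = {
    "rename_columns": {"input": "instruction"},
    "retain_columns": ["input", "response"],
}

# Correct: retain refers to the post-rename column name.
good = {
    "rename_columns": {"input": "instruction"},
    "retain_columns": ["instruction", "response"],
}
```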
