Skip to content

Commit fb3ace8

Browse files
feat: make adding the EOS token a flag so we don't force it on every handler (#467)
* Rethink add_eos by making it a flag so we don't force it on every handler. Signed-off-by: Dushyant Behl <[email protected]> * Merge with main before adding unit tests Signed-off-by: Abhishek <[email protected]> * Added documentation and test case Signed-off-by: Abhishek <[email protected]> * Added documentation and test case Signed-off-by: Abhishek <[email protected]> * Added documentation and test case Signed-off-by: Abhishek <[email protected]> * Updated test case Signed-off-by: Abhishek <[email protected]> --------- Signed-off-by: Dushyant Behl <[email protected]> Signed-off-by: Abhishek <[email protected]> Co-authored-by: Abhishek <[email protected]>
1 parent 2f033c7 commit fb3ace8

File tree

6 files changed

+122
-8
lines changed

6 files changed

+122
-8
lines changed

docs/advanced-data-preprocessing.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,17 @@ Users can also pass any number of `kwargs` arguments required for each data hand
214214

215215
#### Preexisting data handlers
216216
This library currently supports the following [preexisting data handlers](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py#L156):
217-
- `tokenize_and_apply_input_masking`:
218-
Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
219217
- `add_tokenizer_eos_token`:
220218
Appends the tokenizer's EOS token to a specified dataset field.
221219
- `apply_custom_data_formatting_template`:
222220
Applies a custom template (e.g., Alpaca style) to format dataset elements.
221+
By default this handler adds `EOS_TOKEN`, which can be disabled via a handler argument; [see example](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_template.yaml)
222+
- `tokenize_and_apply_input_masking`:
223+
Tokenizes input text and applies masking to the labels for causal language modeling tasks, good for input/output datasets.
224+
By default this handler adds `EOS_TOKEN`, which can be disabled via a handler argument; [see example](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml)
223225
- `apply_custom_jinja_template`:
224226
Applies a custom jinja template (e.g., Alpaca style) to format dataset elements.
227+
By default this handler adds `EOS_TOKEN`, which can be disabled via a handler argument; [see example](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml)
225228
- `apply_tokenizer_chat_template`:
226229
Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
227230
- `duplicate_columns`:

tests/artifacts/predefined_data_configs/apply_custom_jinja_template.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ datasets:
1111
batched: false
1212
fn_kwargs:
1313
dataset_text_field: "dataset_text_field"
14-
template: "dataset_template"
14+
template: "dataset_template"
15+
add_eos_token: true

tests/artifacts/predefined_data_configs/apply_custom_template.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ datasets:
1111
batched: false
1212
fn_kwargs:
1313
dataset_text_field: "dataset_text_field"
14-
template: "dataset_template"
14+
template: "dataset_template"
15+
add_eos_token: true

tests/artifacts/predefined_data_configs/tokenize_and_apply_input_masking.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ datasets:
1111
batched: false
1212
fn_kwargs:
1313
input_field_name: input
14-
output_field_name: output
14+
output_field_name: output
15+
add_eos_token: true

tests/data/test_data_preprocessing.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,102 @@ def test_process_dataconfig_file(data_config_path, data_path):
762762
assert formatted_dataset_field in set(train_set.column_names)
763763

764764

765+
@pytest.mark.parametrize(
766+
"data_config_path, data_path, add_eos_token",
767+
[
768+
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, True),
769+
(DATA_CONFIG_APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON, False),
770+
(
771+
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
772+
TWITTER_COMPLAINTS_DATA_JSON,
773+
True,
774+
),
775+
(
776+
DATA_CONFIG_APPLY_CUSTOM_JINJA_TEMPLATE_YAML,
777+
TWITTER_COMPLAINTS_DATA_JSON,
778+
False,
779+
),
780+
(
781+
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
782+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
783+
True,
784+
),
785+
(
786+
DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
787+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
788+
False,
789+
),
790+
],
791+
)
792+
def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_token):
793+
"""Ensure that the data handlers correctly apply add_eos_token flag to append/remove eos_token."""
794+
with open(data_config_path, "r") as f:
795+
yaml_content = yaml.safe_load(f)
796+
yaml_content["datasets"][0]["data_paths"][0] = data_path
797+
datasets_name = yaml_content["datasets"][0]["name"]
798+
799+
# Modify input_field_name and output_field_name according to dataset
800+
if datasets_name == "text_dataset_input_output_masking":
801+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
802+
"input_field_name"
803+
] = "input"
804+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
805+
"output_field_name"
806+
] = "output"
807+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
808+
"add_eos_token"
809+
] = add_eos_token
810+
811+
# Modify dataset_text_field and template according to dataset
812+
formatted_dataset_field = "formatted_data_field"
813+
if datasets_name in (
814+
"apply_custom_data_template",
815+
"apply_custom_data_jinja_template",
816+
):
817+
template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
818+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
819+
"dataset_text_field"
820+
] = formatted_dataset_field
821+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
822+
"template"
823+
] = template
824+
yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"][
825+
"add_eos_token"
826+
] = add_eos_token
827+
828+
with tempfile.NamedTemporaryFile(
829+
"w", delete=False, suffix=".yaml"
830+
) as temp_yaml_file:
831+
yaml.dump(yaml_content, temp_yaml_file)
832+
temp_yaml_file_path = temp_yaml_file.name
833+
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)
834+
835+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
836+
tokenizer.add_special_tokens({"eos_token": "</s>"})
837+
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
838+
assert isinstance(train_set, Dataset)
839+
if datasets_name == "text_dataset_input_output_masking":
840+
column_names = set(["input_ids", "attention_mask", "labels"])
841+
assert set(train_set.column_names) == column_names
842+
assert (
843+
train_set[0]["input_ids"][-1] == tokenizer.eos_token_id
844+
if add_eos_token
845+
else train_set[0]["input_ids"][-1] != tokenizer.eos_token_id
846+
)
847+
elif datasets_name == "pretokenized_dataset":
848+
assert set(["input_ids", "labels"]).issubset(set(train_set.column_names))
849+
elif datasets_name in (
850+
"apply_custom_data_template",
851+
"apply_custom_data_jinja_template",
852+
):
853+
assert formatted_dataset_field in set(train_set.column_names)
854+
assert (
855+
train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token)
856+
if add_eos_token
857+
else not train_set[0][formatted_dataset_field].endswith(tokenizer.eos_token)
858+
)
859+
860+
765861
@pytest.mark.parametrize(
766862
"data_config_path, data_path_list",
767863
[

tuning/data/data_handlers.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def tokenize_and_apply_input_masking(
5858
column_names: List[str],
5959
input_field_name: str,
6060
output_field_name: str,
61+
add_eos_token: bool = True,
6162
**kwargs,
6263
):
6364
"""Function (data handler) to tokenize and apply instruction masking on dataset
@@ -68,6 +69,7 @@ def tokenize_and_apply_input_masking(
6869
column_names: Name of all the columns in the dataset.
6970
input_field_name: Name of the input (instruction) field in dataset
7071
output_field_name: Name of the output field in dataset
72+
add_eos_token: should add tokenizer.eos_token to text or not, defaults to True
7173
**kwargs: Any additional args passed to the handler
7274
Returns:
7375
Formatted Dataset element with input_ids, labels and attention_mask columns
@@ -83,7 +85,11 @@ def tokenize_and_apply_input_masking(
8385
input_text = element[input_field_name]
8486
output_text = element[output_field_name]
8587

86-
combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)
88+
eos_token = ""
89+
if add_eos_token:
90+
eos_token = tokenizer.eos_token
91+
92+
combined = combine_sequence(input_text, output_text, eos_token=eos_token)
8793

8894
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
8995

@@ -131,6 +137,7 @@ def apply_custom_data_formatting_template(
131137
tokenizer: AutoTokenizer,
132138
dataset_text_field: str,
133139
template: str,
140+
add_eos_token: bool = True,
134141
**kwargs,
135142
):
136143
"""Function (data handler) to format datasets with Alpaca style / other templates.
@@ -142,12 +149,14 @@ def apply_custom_data_formatting_template(
142149
dataset_text_field: Text column name of the dataset where formatted text is saved.
143150
template: Template to format data with. Features of Dataset
144151
should be referred to by {{key}}
152+
add_eos_token: should add tokenizer.eos_token to text or not, defaults to True
145153
Returns:
146154
Formatted Dataset element by formatting dataset with template+tokenizer.EOS_TOKEN
147155
Saves the result to dataset_text_field argument.
148156
"""
149157

150-
template += tokenizer.eos_token
158+
if add_eos_token:
159+
template += tokenizer.eos_token
151160

152161
def replace_text(match_obj):
153162
captured_groups = match_obj.groups()
@@ -174,6 +183,7 @@ def apply_custom_jinja_template(
174183
tokenizer: AutoTokenizer,
175184
dataset_text_field: str,
176185
template: str,
186+
add_eos_token: bool = True,
177187
**kwargs,
178188
):
179189
"""Function (data handler) to format datasets with jinja templates.
@@ -185,12 +195,14 @@ def apply_custom_jinja_template(
185195
dataset_text_field: formatted_dataset_field.
186196
template: Template to format data with. Features of Dataset
187197
should be referred to by {{key}}.
198+
add_eos_token: should add tokenizer.eos_token to text or not, defaults to True
188199
Returns:
189200
Formatted HF Dataset element by formatting dataset with provided jinja template
190201
Saves the result to dataset_text_field argument.
191202
"""
203+
if add_eos_token:
204+
template += tokenizer.eos_token
192205

193-
template += tokenizer.eos_token
194206
template = process_jinja_placeholders(template)
195207
env = SandboxedEnvironment(undefined=StrictUndefined)
196208

0 commit comments

Comments
 (0)