
Commit a89a4a3

fix tokenize_and_apply_input_masking kwargs (#465)
Signed-off-by: Abhishek <[email protected]>
1 parent 381fdd5 · commit a89a4a3

2 files changed: +8 −7 lines changed


tests/data/test_data_preprocessing.py

Lines changed: 3 additions & 1 deletion
@@ -1173,9 +1173,10 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
 def test_process_dataargs(data_args, is_padding_free):
     """Ensure that the train/eval data are properly formatted based on the data args / text field"""
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    max_seq_length = 5
     TRAIN_ARGS = configs.TrainingArguments(
         packing=False,
-        max_seq_length=1024,
+        max_seq_length=max_seq_length,
         output_dir="tmp",  # Not needed but positional
     )
     (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
@@ -1187,6 +1188,7 @@ def test_process_dataargs(data_args, is_padding_free):
         column_names = set(["input_ids", "attention_mask", "labels"])
         assert set(eval_set.column_names) == column_names
         assert set(train_set.column_names) == column_names
+        assert len(train_set[0]["input_ids"]) == max_seq_length
     else:
         assert dataset_text_field in train_set.column_names
         assert dataset_text_field in eval_set.column_names
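
The new assertion only holds if the tokenizer kwargs actually reach the tokenizer, so that input_ids are truncated to max_seq_length. A minimal standalone sketch of that tokenizer behavior, assuming an arbitrary HF tokenizer (gpt2 here as a stand-in for MODEL_NAME):

# Sketch: with truncation enabled, a HF tokenizer caps input_ids at
# max_length, which is the behavior the new assertion verifies end to end.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for MODEL_NAME
ids = tokenizer(
    "a long instruction that clearly exceeds five tokens",
    truncation=True,
    max_length=5,
).input_ids
assert len(ids) == 5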

tuning/data/data_handlers.py

Lines changed: 5 additions & 6 deletions
@@ -58,7 +58,7 @@ def tokenize_and_apply_input_masking(
     column_names: List[str],
     input_field_name: str,
     output_field_name: str,
-    **tokenizer_kwargs,
+    **kwargs,
 ):
     """Function (data handler) to tokenize and apply instruction masking on dataset
     Expects to be run as a HF Map API function.
@@ -68,7 +68,7 @@ def tokenize_and_apply_input_masking(
         column_names: Name of all the columns in the dataset.
         input_field_name: Name of the input (instruction) field in dataset
         output_field_name: Name of the output field in dataset
-        **tokenizer_kwargs: Any additional kwargs to be passed to tokenizer
+        **kwargs: Any additional args passed to the handler
     Returns:
         Formatted Dataset element with input_ids, labels and attention_mask columns
     """
@@ -85,11 +85,10 @@ def tokenize_and_apply_input_masking(
 
     combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)
 
-    fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
-    tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {})
+    tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
 
-    tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs)
-    tokenized_input = tokenizer(input_text, **tokenizer_inner_kwargs)
+    tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
+    tokenized_input = tokenizer(input_text, **tokenizer_kwargs)
 
     masked_labels = [-100] * len(
         tokenized_input.input_ids
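
The underlying bug: the handler runs under the HF Map API, and Dataset.map forwards fn_kwargs to the mapped function as top-level keyword arguments, so there is no nested "fn_kwargs" key to look up inside the handler. A minimal sketch of that forwarding, using a hypothetical handler in place of tokenize_and_apply_input_masking:

# Sketch: Dataset.map passes fn_kwargs as top-level keyword arguments,
# so "tokenizer_kwargs" is found directly in **kwargs, as the fix assumes.
from datasets import Dataset

def handler(example, **kwargs):  # hypothetical stand-in for the data handler
    tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
    example["max_length"] = tokenizer_kwargs.get("max_length")
    return example

ds = Dataset.from_dict({"text": ["hello"]})
ds = ds.map(handler, fn_kwargs={"tokenizer_kwargs": {"max_length": 5}})
assert ds[0]["max_length"] == 5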
