@@ -58,7 +58,7 @@ def tokenize_and_apply_input_masking(
5858 column_names : List [str ],
5959 input_field_name : str ,
6060 output_field_name : str ,
61- ** tokenizer_kwargs ,
61+ ** kwargs ,
6262):
6363 """Function (data handler) to tokenize and apply instruction masking on dataset
6464 Expects to be run as a HF Map API function.
@@ -68,7 +68,7 @@ def tokenize_and_apply_input_masking(
6868 column_names: Name of all the columns in the dataset.
6969 input_field_name: Name of the input (instruction) field in dataset
7070 output_field_name: Name of the output field in dataset
71- **tokenizer_kwargs : Any additional kwargs to be passed to tokenizer
71+ **kwargs : Any additional args passed to the handler
7272 Returns:
7373 Formatted Dataset element with input_ids, labels and attention_mask columns
7474 """
@@ -85,11 +85,10 @@ def tokenize_and_apply_input_masking(
8585
8686 combined = combine_sequence (input_text , output_text , eos_token = tokenizer .eos_token )
8787
88- fn_kwargs = tokenizer_kwargs .get ("fn_kwargs" , {})
89- tokenizer_inner_kwargs = fn_kwargs .get ("tokenizer_kwargs" , {})
88+ tokenizer_kwargs = kwargs .get ("tokenizer_kwargs" , {})
9089
91- tokenized_comb_seqs = tokenizer (combined , ** tokenizer_inner_kwargs )
92- tokenized_input = tokenizer (input_text , ** tokenizer_inner_kwargs )
90+ tokenized_comb_seqs = tokenizer (combined , ** tokenizer_kwargs )
91+ tokenized_input = tokenizer (input_text , ** tokenizer_kwargs )
9392
9493 masked_labels = [- 100 ] * len (
9594 tokenized_input .input_ids
0 commit comments