@@ -55,44 +55,55 @@ def get_data_collator(
5555 Callable collator to be leveraged by the trainer.
5656 """
5757
58+ if packing :
59+ if is_traindata_tokenized :
60+ # Packing with a tokenized dataset requires the seq2seq collator.
61+ return DataCollatorForSeq2Seq (
62+ tokenizer = tokenizer , padding = False , max_length = max_seq_length
63+ )
64+
65+ # Packing a non-tokenized dataset doesn't require a collator with SFTTrainer.
66+ return None
67+
68+ # TODO: near term - how the response template ids are parsed out needs to be cleaned up.
69+ # The [2:] here applies if the response template has a \n prefix; it is needed to strip the \n,
70+ # otherwise the template is not found. We will create an issue to clean this up after we discuss
71+ # the data formats and collators we will support.
5872 if response_template and instruction_template :
73+ # Pass both instruction and response template for chat style training.
5974 return DataCollatorForCompletionOnlyLM (
6075 response_template = response_template ,
6176 instruction_template = instruction_template ,
6277 tokenizer = tokenizer ,
6378 ignore_index = configs .IGNORE_INDEX ,
6479 )
6580
66- if not packing :
67- # TODO: near term - how response template ids are parsed out needs to be cleaned.
68- # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
69- # otherwise template is not found. We will create issue to clean this out after we discuss
70- # data formats and collators we will support.
71- if response_template :
72- response_template_ids = tokenizer .encode (
73- response_template , add_special_tokens = False
74- )[2 :]
75- return DataCollatorForCompletionOnlyLM (
76- response_template = response_template_ids ,
77- tokenizer = tokenizer ,
78- ignore_index = configs .IGNORE_INDEX ,
79- )
81+ if response_template :
82+ response_template_ids = tokenizer .encode (
83+ response_template , add_special_tokens = False
84+ )[2 :]
85+ return DataCollatorForCompletionOnlyLM (
86+ response_template = response_template_ids ,
87+ tokenizer = tokenizer ,
88+ ignore_index = configs .IGNORE_INDEX ,
89+ )
8090
81- if is_padding_free :
82- # when packing is false but padding_free is used and
83- # no response template is used then its a pretrained scenario.
84- # Current plugin in fms-acceleration is compatible with
85- # `DataCollatorForSeq2Seq` collator hence we use this.
86- return DataCollatorForSeq2Seq (
87- tokenizer = tokenizer , padding = False , max_length = max_seq_length
88- )
91+ if is_padding_free :
92+ # When packing is false but padding_free is used and
93+ # no response template is used, it's a pretraining scenario.
94+ # Current plugin in fms-acceleration is compatible with
95+ # `DataCollatorForSeq2Seq` collator hence we use this.
96+ return DataCollatorForSeq2Seq (
97+ tokenizer = tokenizer , padding = False , max_length = max_seq_length
98+ )
8999
100+ if is_traindata_tokenized :
90101 # Note that this automatically pads labels with -100
91102 # TODO check if this is sufficient for preprocessed
92- if is_traindata_tokenized :
93- return DataCollatorForSeq2Seq (
94- tokenizer = tokenizer , padding = True , max_length = max_seq_length
95- )
96- raise ValueError (
97- "Could not pick a data collator. Please refer to supported data formats"
103+ return DataCollatorForSeq2Seq (
104+ tokenizer = tokenizer , padding = True , max_length = max_seq_length
98105 )
106+
107+ raise ValueError (
108+ "Could not pick a data collator. Please refer to supported data formats"
109+ )
0 commit comments