25 | 25 | from tuning.config import configs |
26 | 26 |
27 | 27 |
28 | | -def validate_data_args( |
| 28 | +def validate_data_args(data_args: configs.DataArguments, packing: bool): |
| 29 | + |
| 30 | + assert isinstance( |
| 31 | + data_args.training_data_path, str |
| 32 | + ), "Training data path has to be set and str" |
| 33 | + |
| 34 | + # Dataset containing single sequence needs a response template for masking |
| 35 | + if data_args.response_template is None and data_args.dataset_text_field is not None: |
| 36 | + if packing is False: |
| 37 | + raise ValueError( |
| 38 | +                "Since dataset_text_field is provided and packing is disabled, \
| 39 | +                a corresponding response_template is needed for masking"
| 40 | + ) |
| 41 | + |
| 42 | +    # Currently if packing is false, we require a response_template. This may change in the future.
| 43 | + if packing is False: |
| 44 | + if data_args.response_template is None: |
| 45 | + raise ValueError( |
| 46 | +                "Response template is None; it needs to be set for training \
| 47 | +                    with packing disabled."
| 48 | + ) |
| 49 | + |
| 50 | + if data_args.response_template: |
| 51 | +        # To use a response template, pass datasets with single-sequence instances
| 52 | +        # or a formatter template that creates single sequences on the fly.
| 53 | + if not (data_args.dataset_text_field or data_args.data_formatter_template): |
| 54 | + raise ValueError( |
| 55 | + "dataset_text_field and data_formatter_template are None. \ |
| 56 | + One of them needs to be set to use response_template" |
| 57 | + ) |
| 58 | + # Only one of dataset_text_field or data_formatter_template should be set. |
| 59 | + if data_args.dataset_text_field and data_args.data_formatter_template: |
| 60 | + raise ValueError( |
| 61 | + "dataset_text_field and data_formatter_template are both set,\ |
| 62 | + but are mutually exclusive options" |
| 63 | + ) |
| 64 | +    # TODO(s) In the future support two more formats:
| 65 | + # 1. Allow no response template, and JSON with input/output fields and mask input |
| 66 | + |
| 67 | + # 2. Allow pretokenized Dataset besides JSON. |
| 68 | + |
| 69 | + |
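To make the validation rules above concrete, here is a minimal sketch of how they fire. It assumes `configs.DataArguments` is a dataclass exposing the four fields referenced in the function; the path and template strings are illustrative only.

```python
from tuning.config import configs

# dataset_text_field set with packing disabled, but no response_template:
# the first check raises, asking for a template to mask with.
args = configs.DataArguments(
    training_data_path="train.jsonl",  # illustrative path
    dataset_text_field="output",
)
try:
    validate_data_args(args, packing=False)
except ValueError as err:
    print(err)  # needs a corresponding response template for masking

# Adding a response template satisfies both non-packing checks.
args = configs.DataArguments(
    training_data_path="train.jsonl",
    dataset_text_field="output",
    response_template="\n### Response:",  # illustrative template
)
validate_data_args(args, packing=False)  # passes silently
```

Setting `data_formatter_template` alongside `dataset_text_field` would also raise, since the two are mutually exclusive.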
| 70 | +def get_data_collator( |
| 71 | + packing: bool, |
| 72 | + response_template: Optional[str], |
| 73 | + tokenizer: AutoTokenizer, |
| 74 | +) -> Callable: |
| 75 | +    """Create and return the appropriate collator type based on the configuration for packing
| 76 | +    and response_template.
| 77 | +
| 78 | + Args: |
| 79 | + packing: bool |
| 80 | +            Whether or not we should apply packing.
| 81 | + response_template: Optional[str] |
| 82 | + Response template to be used for formatting by TRL. |
| 83 | + tokenizer: AutoTokenizer |
| 84 | + Loaded tokenizer object to be used by the collator. |
| 85 | +
| 86 | + Returns: |
| 87 | + Callable |
| 88 | + Callable collator to be leveraged by the trainer. |
| 89 | + """ |
| 90 | + if not packing: |
| 91 | +        # TODO: near term - how response template ids are parsed out needs to be cleaned up.
| 92 | +        # The [2:] applies when the response template has a \n prefix; the \n must be stripped,
| 93 | +        # otherwise the template is not found. We will open an issue to clean this up once we
| 94 | +        # have discussed which data formats and collators we will support.
| 95 | + if response_template: |
| 96 | + response_template_ids = tokenizer.encode( |
| 97 | + response_template, add_special_tokens=False |
| 98 | + )[2:] |
| 99 | + return DataCollatorForCompletionOnlyLM( |
| 100 | + response_template=response_template_ids, |
| 101 | + tokenizer=tokenizer, |
| 102 | + ignore_index=configs.IGNORE_INDEX, |
| 103 | + ) |
| 104 | +    # TODO with future changes:
| 105 | +    # 1. Support no packing and a seq2seq collator without a response template
| 106 | + # # if dataset_text_field is None and response_template is None: |
| 107 | + # # Use the seq2seq data collator; |
| 108 | + # # Note that this automatically pads labels with -100 |
| 109 | + # return DataCollatorForSeq2Seq( |
| 110 | + # tokenizer=tokenizer, padding=True, max_length=max_sequence_length |
| 111 | + # ) |
| 112 | +    # 2. Add anything needed for preprocessed input
| 113 | + |
| 114 | + |
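The `[2:]` slice is the subtle part here: with sentencepiece-style tokenizers, encoding a template that begins with `\n` in isolation produces leading token ids that do not occur when the same template appears mid-sequence, so they must be dropped or the collator never finds the template. A minimal sketch of the intended call, assuming a LLaMA-style tokenizer (the model name is illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative

template = "\n### Response:"
ids = tokenizer.encode(template, add_special_tokens=False)
print(ids[:2], ids[2:])  # the first two ids are the context being stripped

# The returned DataCollatorForCompletionOnlyLM masks everything before the
# template with IGNORE_INDEX, so only completions contribute to the loss.
collator = get_data_collator(
    packing=False,
    response_template=template,
    tokenizer=tokenizer,
)
```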
| 115 | +################################################################################### |
| 116 | +### The functions below are not yet used. Iterative development towards new features |
| 117 | + |
| 118 | + |
| 119 | +def get_data_collator_temp( |
| 120 | + packing: bool, |
29 | 121 | dataset_text_field: Optional[str], |
30 | 122 | response_template: Optional[str], |
31 | | -): |
32 | | - # Dataset containing single sequence needs a single sequence and a response template |
33 | | - if dataset_text_field is None and response_template is not None: |
34 | | - raise ValueError( |
35 | | - "Needs a corresponding dataset_text_feld \ |
36 | | - in which to look for response_template" |
37 | | - ) |
38 | | - if response_template is None and dataset_text_field is not None: |
39 | | - raise ValueError( |
40 | | - "Since dataset_text_field is provided, \ |
41 | | - needs a corresponding response template for masking" |
42 | | - ) |
43 | | - # Dataset containing JSON with fields and a formatter template |
44 | | - # TO DO load JSON and check input/output field is present |
| 123 | + max_sequence_length: int, |
| 124 | + tokenizer: AutoTokenizer, |
| 125 | +) -> Callable: |
| 126 | +    """Create and return the appropriate collator type based on the configuration for packing,
| 127 | + response_template, and dataset_text_field. |
45 | 128 |
46 | | - # in future : pretokenized Dataset may be added. |
| 129 | + Args: |
| 130 | + packing: bool |
| 131 | +            Whether or not we should apply packing.
| 132 | + dataset_text_field: Optional[str] |
| 133 | +            Dataset text field to be used for formatting by TRL.
| 134 | + response_template: Optional[str] |
| 135 | + Response template to be used for formatting by TRL. |
| 136 | + max_sequence_length: int |
| 137 | + Max sequence length to be used for sequence tokenization. |
| 138 | + tokenizer: AutoTokenizer |
| 139 | + Loaded tokenizer object to be used by the collator. |
| 140 | +
| 141 | + Returns: |
| 142 | + Callable |
| 143 | + Callable collator to be leveraged by the trainer. |
| 144 | + """ |
| 145 | + if not packing: |
| 146 | + if dataset_text_field is None and response_template is None: |
| 147 | + # Use the seq2seq data collator; note that this automatically pads labels with -100 |
| 148 | + return DataCollatorForSeq2Seq( |
| 149 | + tokenizer=tokenizer, padding=True, max_length=max_sequence_length |
| 150 | + ) |
| 151 | +        # TODO: near term - how response template ids are parsed out needs to be cleaned up.
| 152 | +        # The [2:] applies when the response template has a \n prefix; the \n must be stripped,
| 153 | +        # otherwise the template is not found. We will open an issue to clean this up once we
| 154 | +        # have discussed which data formats and collators we will support.
| 155 | + response_template_ids = tokenizer.encode( |
| 156 | + response_template, add_special_tokens=False |
| 157 | + )[2:] |
| 158 | + return DataCollatorForCompletionOnlyLM( |
| 159 | + response_template=response_template_ids, |
| 160 | + tokenizer=tokenizer, |
| 161 | + ignore_index=configs.IGNORE_INDEX, |
| 162 | + ) |
47 | 163 |
48 | 164 |
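The two non-packing branches of this temporary helper can be exercised as below; a sketch with illustrative tokenizer and length values. Note that under `validate_data_args` as written, the no-template branch is currently unreachable from the CLI flow, since a response template is required whenever packing is disabled.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Neither dataset_text_field nor response_template: falls through to
# DataCollatorForSeq2Seq, which pads labels with -100 automatically.
seq2seq_collator = get_data_collator_temp(
    packing=False,
    dataset_text_field=None,
    response_template=None,
    max_sequence_length=512,
    tokenizer=tokenizer,
)

# With a response template: DataCollatorForCompletionOnlyLM masks the prompt
# so that only tokens after the template are scored.
masking_collator = get_data_collator_temp(
    packing=False,
    dataset_text_field="output",
    response_template="\n### Response:",
    max_sequence_length=512,
    tokenizer=tokenizer,
)
```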
49 | 165 | def get_data_trainer_kwargs( |
@@ -82,7 +198,7 @@ def get_data_trainer_kwargs( |
82 | 198 | Dict[str, Any] |
83 | 199 | Data related kwargs to be used by the SFT Trainer. |
84 | 200 | """ |
85 | | - data_collator = get_data_collator( |
| 201 | + data_collator = get_data_collator_temp( |
86 | 202 | packing, dataset_text_field, response_template, max_sequence_length, tokenizer |
87 | 203 | ) |
88 | 204 | eval_dataset = None |
@@ -122,52 +238,6 @@ def get_data_trainer_kwargs( |
122 | 238 | return data_kwargs |
123 | 239 |
124 | 240 |
125 | | -def get_data_collator( |
126 | | - packing: bool, |
127 | | - dataset_text_field: Optional[str], |
128 | | - response_template: Optional[str], |
129 | | - max_sequence_length: int, |
130 | | - tokenizer: AutoTokenizer, |
131 | | -) -> Callable: |
132 | | - """Create and return the the appropriate collator type based on the configuration for packing, |
133 | | - response_template, and dataset_text_field. |
134 | | -
135 | | - Args: |
136 | | - packing: bool |
137 | | - Whether or not we should apply packing or not. |
138 | | - dataset_text_field: Optional[str] |
139 | | - Dataset text field fto be used for formatting by TRL. |
140 | | - response_template: Optional[str] |
141 | | - Response template to be used for formatting by TRL. |
142 | | - max_sequence_length: int |
143 | | - Max sequence length to be used for sequence tokenization. |
144 | | - tokenizer: AutoTokenizer |
145 | | - Loaded tokenizer object to be used by the collator. |
146 | | -
147 | | - Returns: |
148 | | - Callable |
149 | | - Callable collator to be leveraged by the trainer. |
150 | | - """ |
151 | | - if not packing: |
152 | | - if dataset_text_field is None and response_template is None: |
153 | | - # Use the seq2seq data collator; note that this automatically pads labels with -100 |
154 | | - return DataCollatorForSeq2Seq( |
155 | | - tokenizer=tokenizer, padding=True, max_length=max_sequence_length |
156 | | - ) |
157 | | - # TODO: near term - how response template ids are parsed out needs to be cleaned. |
158 | | - # The [2:] here applies if response template has \n prefix, it is needed to strip \n, |
159 | | - # otherwise template is not found. We will create issue to clean this out after we discuss |
160 | | - # data formats and collators we will support. |
161 | | - response_template_ids = tokenizer.encode( |
162 | | - response_template, add_special_tokens=False |
163 | | - )[2:] |
164 | | - return DataCollatorForCompletionOnlyLM( |
165 | | - response_template=response_template_ids, |
166 | | - tokenizer=tokenizer, |
167 | | - ignore_index=configs.IGNORE_INDEX, |
168 | | - ) |
169 | | - |
170 | | - |
171 | 241 | def get_formatted_dataset( |
172 | 242 | data_path: str, dataset_text_field: str, tokenizer: AutoTokenizer |
173 | 243 | ) -> Dataset: |