2 files changed: 13 insertions(+), 12 deletions(-)

@@ -49,7 +49,7 @@ def parse_requirements(extras_require_map):
     try:
         torch_version = version("torch")
     except PackageNotFoundError:
-        torch_version = "2.6.0"  # default to torch 2.6
+        torch_version = "2.8.0"  # default to torch 2.8.0
     _install_requires.append(f"torch=={torch_version}")

     version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
src/axolotl/core/trainers

@@ -225,17 +225,6 @@ def _get_dataloader(
 
         data_collator = self.data_collator if is_training else self.eval_data_collator
 
-        if dataset.column_names and "length" in dataset.column_names:
-            dataset = dataset.remove_columns(["length"])
-        if (
-            dataset.column_names
-            and "position_ids" in dataset.column_names
-            and "attention_mask" in dataset.column_names
-            and self.args.sample_packing
-            and self.args.sample_packing_drop_attention_mask
-        ):
-            dataset = dataset.remove_columns(["attention_mask"])
-
         if isinstance(dataset, datasets.Dataset):
             if is_training:
                 if not self.args.sample_packing or self.args.pretraining:
@@ -294,6 +283,18 @@ def _get_dataloader(
         ):
             self.accelerator.even_batches = False
 
+        if dataset.column_names and "length" in dataset.column_names:
+            dataset = dataset.remove_columns(["length"])
+
+        if (
+            dataset.column_names
+            and "position_ids" in dataset.column_names
+            and "attention_mask" in dataset.column_names
+            and self.args.sample_packing
+            and self.args.sample_packing_drop_attention_mask
+        ):
+            dataset = dataset.remove_columns(["attention_mask"])
+
         dataloader = DataLoader(dataset, **dataloader_params)
 
         # Accelerator.free_memory() will destroy the references, so
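Taken together, the two hunks relocate the column pruning from before the sampler setup to immediately before the DataLoader is built, presumably so the code in between can still read those columns. A minimal sketch of the pruning itself, with a hypothetical toy dataset and stand-in flags for `args.sample_packing` and `args.sample_packing_drop_attention_mask` (this is not the trainer code, just the column logic in isolation):

import datasets
from torch.utils.data import DataLoader

# Hypothetical packed example: position_ids restarting at 0 mark the
# boundary between two packed sequences within one row.
ds = datasets.Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3, 4]],
        "position_ids": [[0, 1, 0, 1]],
        "attention_mask": [[1, 1, 1, 1]],
        "length": [4],  # helper column, not needed by the model forward pass
    }
)

sample_packing = True
drop_attention_mask = True  # stand-in for sample_packing_drop_attention_mask

if ds.column_names and "length" in ds.column_names:
    ds = ds.remove_columns(["length"])

if (
    "position_ids" in ds.column_names
    and "attention_mask" in ds.column_names
    and sample_packing
    and drop_attention_mask
):
    # position_ids already encode the packed-sequence boundaries, so the
    # attention_mask column is redundant here and can be dropped.
    ds = ds.remove_columns(["attention_mask"])

loader = DataLoader(ds.with_format("torch"), batch_size=1)
print(next(iter(loader)).keys())  # dict_keys(['input_ids', 'position_ids'])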