
Commit 3750fdc

qywu and winglian authored
Fix trainer dataloader slow loading issue (#3219)
* Fix trainer dataloader handling in src/axolotl/core/trainers/base.py
* Update comment to reflect torch version

Co-authored-by: Wing Lian <[email protected]>
1 parent 613bcf9 commit 3750fdc

File tree: 2 files changed (+13, -12 lines)


setup.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def parse_requirements(extras_require_map):
     try:
         torch_version = version("torch")
     except PackageNotFoundError:
-        torch_version = "2.6.0"  # default to torch 2.6
+        torch_version = "2.8.0"  # default to torch 2.8.0
     _install_requires.append(f"torch=={torch_version}")
 
     version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
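
For context, the fallback this hunk updates is the usual importlib.metadata pattern: read the installed torch version and pin a default when torch is not yet installed. A minimal, self-contained sketch of that pattern (assuming `version` and `PackageNotFoundError` come from importlib.metadata, as they typically do in setup scripts like this one):

# Sketch of the version-detection pattern the setup.py hunk touches, not the file itself.
from importlib.metadata import PackageNotFoundError, version
import re

try:
    torch_version = version("torch")          # version of the installed torch package
except PackageNotFoundError:
    torch_version = "2.8.0"                   # pinned default when torch is absent

# Same regex as in the diff context: extract major/minor(/patch) components.
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
    major, minor = int(version_match.group(1)), int(version_match.group(2))
    print(f"torch {major}.{minor}")

The only behavioral change in the commit is the pinned default moving from 2.6.0 to 2.8.0, with the comment updated to match.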

src/axolotl/core/trainers/base.py

Lines changed: 12 additions & 11 deletions
@@ -225,17 +225,6 @@ def _get_dataloader(
 
         data_collator = self.data_collator if is_training else self.eval_data_collator
 
-        if dataset.column_names and "length" in dataset.column_names:
-            dataset = dataset.remove_columns(["length"])
-        if (
-            dataset.column_names
-            and "position_ids" in dataset.column_names
-            and "attention_mask" in dataset.column_names
-            and self.args.sample_packing
-            and self.args.sample_packing_drop_attention_mask
-        ):
-            dataset = dataset.remove_columns(["attention_mask"])
-
         if isinstance(dataset, datasets.Dataset):
             if is_training:
                 if not self.args.sample_packing or self.args.pretraining:
@@ -294,6 +283,18 @@ def _get_dataloader(
         ):
             self.accelerator.even_batches = False
 
+        if dataset.column_names and "length" in dataset.column_names:
+            dataset = dataset.remove_columns(["length"])
+
+        if (
+            dataset.column_names
+            and "position_ids" in dataset.column_names
+            and "attention_mask" in dataset.column_names
+            and self.args.sample_packing
+            and self.args.sample_packing_drop_attention_mask
+        ):
+            dataset = dataset.remove_columns(["attention_mask"])
+
         dataloader = DataLoader(dataset, **dataloader_params)
 
         # Accelerator.free_memory() will destroy the references, so
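
The patch does not change what is dropped, only when: the column-removal calls move from the top of _get_dataloader to immediately before the DataLoader is constructed. A minimal sketch of the underlying technique, using Hugging Face datasets.Dataset.remove_columns on toy data (illustrative only, not the axolotl trainer code):

# Drop a bookkeeping column from a datasets.Dataset before building a torch DataLoader.
# The "length" column name mirrors the field removed in the diff; the values are toy data.
from datasets import Dataset
from torch.utils.data import DataLoader

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5, 6]],
        "length": [3, 3],  # extra column the collator should not receive
    }
)

# remove_columns returns a new Dataset, so the result must be reassigned,
# exactly as the patched code does.
if ds.column_names and "length" in ds.column_names:
    ds = ds.remove_columns(["length"])

loader = DataLoader(ds.with_format("torch"), batch_size=2)
for batch in loader:
    print(batch["input_ids"].shape)  # torch.Size([2, 3]); no "length" key in the batch

Per the commit title, performing this pruning late rather than at the start of _get_dataloader is what resolves the slow dataloader loading; the conditions guarding the attention_mask removal (sample_packing and sample_packing_drop_attention_mask) are carried over unchanged.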
