2 changes: 1 addition & 1 deletion src/axolotl/core/builders/causal.py
@@ -424,7 +424,7 @@ def build_collator(
     ):
         if training_args.pretraining:
             if (
-                self.cfg.pretraining_sample_concatenation is False
+                not self.cfg.pretraining_sample_concatenation
                 or self.cfg.micro_batch_size > 1
             ):
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
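Note on the change above: the old condition, pretraining_sample_concatenation is False, fires only when the flag is explicitly set to False, while the new condition, not pretraining_sample_concatenation, also fires when the flag is unset (None) or otherwise falsy. A minimal standalone sketch of the difference (plain Python, not the axolotl config object):

    # Truth table for the old vs. new check; the values here are hypothetical.
    for value in (True, False, None):
        old_check = value is False   # old: only an explicit False takes the branch
        new_check = not value        # new: False, None, or any other falsy value takes it
        print(f"value={value!r:<5} old={old_check} new={new_check}")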
14 changes: 14 additions & 0 deletions src/axolotl/core/trainers/base.py
@@ -272,6 +272,20 @@ def _get_dataloader(
             num_workers=self.args.dataloader_num_workers,
             rank=self.args.process_index,
         )
+
+        if (
+            self.args.accelerator_config is not None
+            and self.args.accelerator_config.split_batches
+            and self.args.accelerator_config.dispatch_batches
+        ):
+            if self.args.sample_packing and self.args.pretraining:
+                if not self.args.eval_sample_packing and not is_training:
+                    dataloader_params["batch_size"] *= self.accelerator.num_processes
Contributor: Should this only be scaling by the total dp size?

Collaborator (Author): accelerator handles the actual dp sharding, so I don't think so? I'll test.

+                else:
+                    dataloader_params["batch_size"] = self.accelerator.num_processes
+            elif not self.args.sample_packing and self.args.pretraining:
+                dataloader_params["batch_size"] *= self.accelerator.num_processes
+
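For context on the multiplication above, a minimal arithmetic sketch, assuming Accelerate's split_batches + dispatch_batches semantics (the main process builds one global batch that is then split evenly across processes); the helper name is hypothetical, not axolotl code:

    def dataloader_batch_size(per_device_batch_size: int, num_processes: int) -> int:
        # Hand the DataLoader the global batch size so each process still sees
        # per_device_batch_size samples after the batch is dispatched and split.
        global_batch = per_device_batch_size * num_processes
        assert global_batch % num_processes == 0  # split_batches needs an even split
        return global_batch

    # e.g. micro_batch_size=2 on 4 GPUs: the DataLoader yields batches of 8,
    # and each rank receives 2 samples.
    print(dataloader_batch_size(2, 4))  # 8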
Contributor, commenting on lines +276 to +288:

⚠️ Potential issue

Critical: Reintroducing batch_size alongside a BatchSampler causes DataLoader errors

This block can set or mutate dataloader_params["batch_size"] even when a BatchSampler is already configured (e.g., eval with sample packing returns a MultipackBatchSampler). PyTorch's DataLoader forbids passing both batch_sampler and batch_size. In the eval pretraining + sample_packing + eval_sample_packing=True path, we delete batch_size earlier and then re-add it here, leading to a runtime error and/or undefined behavior.

Guard against the presence of batch_sampler before touching batch_size:

         if (
             self.args.accelerator_config is not None
             and self.args.accelerator_config.split_batches
             and self.args.accelerator_config.dispatch_batches
         ):
-            if self.args.sample_packing and self.args.pretraining:
-                if not self.args.eval_sample_packing and not is_training:
-                    dataloader_params["batch_size"] *= self.accelerator.num_processes
-                else:
-                    dataloader_params["batch_size"] = self.accelerator.num_processes
-            elif not self.args.sample_packing and self.args.pretraining:
-                dataloader_params["batch_size"] *= self.accelerator.num_processes
+            # Avoid conflicting with BatchSampler: DataLoader forbids both batch_size and batch_sampler.
+            if "batch_sampler" not in dataloader_params:
+                if self.args.sample_packing and self.args.pretraining:
+                    if not self.args.eval_sample_packing and not is_training:
+                        dataloader_params["batch_size"] *= self.accelerator.num_processes
+                    else:
+                        dataloader_params["batch_size"] = self.accelerator.num_processes
+                elif not self.args.sample_packing and self.args.pretraining:
+                    dataloader_params["batch_size"] *= self.accelerator.num_processes

Follow-up: If you actually need different behavior for the multipack case under split/dispatch, the adjustment must be applied inside the MultipackBatchSampler construction (e.g., its batch_size or batch_max_len), not by reintroducing batch_size here.
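
As a standalone illustration of the constraint this comment relies on (plain PyTorch, unrelated to axolotl's samplers), passing both batch_size and batch_sampler fails at DataLoader construction:

    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

    dataset = list(range(16))
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

    # Fine: the batch sampler alone controls batching.
    DataLoader(dataset, batch_sampler=batch_sampler)

    # ValueError: batch_sampler is mutually exclusive with batch_size, shuffle, sampler, drop_last.
    try:
        DataLoader(dataset, batch_size=4, batch_sampler=batch_sampler)
    except ValueError as err:
        print(err)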

         if self.args.sample_packing and (
             (is_training and not self.args.pretraining)
             or (not is_training and self.args.eval_sample_packing is not False)
Empty file.