
Conversation

@winglian (Collaborator) commented Aug 17, 2025

Batch size handling isn't quite right when split batches are in use. This PR sets the correct dataloader batch size for pretraining, scaling it appropriately when dispatch_batches/split_batches are enabled.
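For context, a rough sketch of the scaling involved (hypothetical numbers; based on my reading of accelerate's split_batches/dispatch_batches semantics, where the dataloader yields the global batch and accelerate splits it across processes):

# Hypothetical numbers, not taken from this PR.
num_processes = 4
per_device_micro_batch_size = 2

# With dispatch_batches=True and split_batches=True, accelerate treats each
# DataLoader batch as the *global* batch and splits it across processes,
# so the DataLoader itself must be built with the scaled size.
dataloader_batch_size = per_device_micro_batch_size * num_processes  # 8
per_device_after_split = dataloader_batch_size // num_processes      # back to 2
print(dataloader_batch_size, per_device_after_split)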

Summary by CodeRabbit

  • Bug Fixes
    • Corrected batch size computation during pretraining when both split and dispatch batching are enabled, with sensible handling across training and evaluation, with and without sample packing. Improves stability and throughput by avoiding unintended per-process sizing.
    • Made pretraining sample concatenation handling more robust to falsy configuration values, ensuring the appropriate collator is selected during pretraining. Prevents misconfiguration that could affect batching and data-processing behavior.

@coderabbitai bot (Contributor) commented Aug 17, 2025

📝 Walkthrough

Walkthrough

Updates the pretraining condition used for causal collator selection and introduces new batch_size adjustment logic in AxolotlTrainer._get_dataloader for specific accelerator configurations during pretraining, with branching based on sample packing and eval/training mode.

Changes

Collator pretraining condition (src/axolotl/core/builders/causal.py)
In build_collator, replaced the explicit check self.cfg.pretraining_sample_concatenation is False with not self.cfg.pretraining_sample_concatenation, broadening falsy handling so that DataCollatorForSeq2Seq is selected during pretraining when concatenation is disabled/unset or micro_batch_size > 1.

Trainer dataloader batch sizing (src/axolotl/core/trainers/base.py)
In AxolotlTrainer._get_dataloader, added conditional adjustment of dataloader_params["batch_size"] when accelerator_config.split_batches and dispatch_batches are both true. With pretraining and sample_packing: in eval without eval_sample_packing, multiply by accelerator.num_processes; otherwise set to accelerator.num_processes. With pretraining and no sample_packing: multiply by accelerator.num_processes. No changes otherwise.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


@winglian winglian marked this pull request as ready for review August 18, 2025 12:42
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
src/axolotl/core/builders/causal.py (1)

425-433: Truthiness change may alter pretraining collator selection for None/unspecified values

Switching from is False to not self.cfg.pretraining_sample_concatenation broadens the condition to treat any falsy value (e.g., None, 0, "") as disabling concatenation. Previously, only an explicit False would take this path. This can change behavior for default/unspecified configs by returning DataCollatorForSeq2Seq where the code previously fell through to the multipack/None path.
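
As a quick illustration of the difference (plain Python, not axolotl code):

# How `is False` and `not value` diverge for unset/None-style config values.
for value in (True, False, None, 0, ""):
    old_check = value is False  # old: only an explicit False matches
    new_check = not value       # new: any falsy value matches
    print(f"{value!r:>6}: is False -> {old_check}, not value -> {new_check}")
# None, 0, and "" now take the DataCollatorForSeq2Seq branch under the new check.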

If the intent was to only diverge on explicit False, keep the identity check:

-            if (
-                not self.cfg.pretraining_sample_concatenation
-                or self.cfg.micro_batch_size > 1
-            ):
+            if (
+                self.cfg.pretraining_sample_concatenation is False
+                or self.cfg.micro_batch_size > 1
+            ):
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)

If the broader behavior is intended, consider documenting the new default semantics and adding tests for pretraining_sample_concatenation being None vs False.

Would you like me to add unit tests covering pretraining with pretraining_sample_concatenation set to True/False/None and micro_batch_size 1/>1?
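
A rough sketch of what such parametrized tests could look like; the builder_for_cfg fixture and the build_collator call shape are assumptions, not axolotl's actual test helpers:

import pytest
from transformers import DataCollatorForSeq2Seq

@pytest.mark.parametrize("concat", [True, False, None])
@pytest.mark.parametrize("micro_batch_size", [1, 2])
def test_pretraining_collator_selection(builder_for_cfg, concat, micro_batch_size):
    # Hypothetical fixture that wires up a trainer builder with a pretraining
    # config; replace with whatever fixture the test suite actually provides.
    builder = builder_for_cfg(
        pretraining=True,
        pretraining_sample_concatenation=concat,
        micro_batch_size=micro_batch_size,
    )
    collator = builder.build_collator(builder.training_args)

    if not concat or micro_batch_size > 1:
        # New behavior: any falsy value (False/None) or micro_batch_size > 1
        # selects DataCollatorForSeq2Seq during pretraining.
        assert isinstance(collator, DataCollatorForSeq2Seq)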

src/axolotl/core/trainers/base.py (2)

276-288: Double-check split vs. dispatch gating and set vs. multiply semantics

  • Gating on both split_batches and dispatch_batches may be overly restrictive. If split_batches=True and dispatch_batches=False, do we still need the adjustment? Please confirm intended accelerate semantics for both flags.
  • In the pretraining + sample_packing path, the training/eval branches sometimes set the batch_size to num_processes instead of multiplying, which assumes per-device bs=1. If users set per-device bs>1 for these modes, this shrinks the effective global batch (see the arithmetic sketch below). Consider scaling multiplicatively when batch_size > 1.

If desired, I can add a helper to normalize global vs per-process batch sizing across these modes and log the computed values for traceability.
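
For concreteness, the arithmetic behind the second bullet, with hypothetical numbers:

num_processes = 4
per_device_bs = 2  # user-configured per-device batch size

set_path = num_processes                  # "set to num_processes": 4 samples dispatched
mul_path = per_device_bs * num_processes  # "multiply": 8 samples dispatched

# After accelerate splits the dispatched batch across 4 processes:
print(set_path // num_processes)  # 1 per device: effective global batch shrinks
print(mul_path // num_processes)  # 2 per device: matches the configured size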


276-288: Add debug logging for batch-size rewrites to aid troubleshooting

Given the complexity of pretraining + sample packing + split/dispatch interactions, add a short LOG.debug with the before/after batch size, mode, and num_processes whenever this rewrite triggers.

I can add the debug statements in this block if you’d like.
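
One possible shape for that debug statement (a sketch; the exact message format is an assumption):

old_bs = dataloader_params["batch_size"]
# ... apply the split/dispatch batch-size adjustment ...
LOG.debug(
    "split/dispatch batch_size rewrite: %s -> %s (is_training=%s, "
    "sample_packing=%s, num_processes=%s)",
    old_bs,
    dataloader_params["batch_size"],
    is_training,
    self.args.sample_packing,
    self.accelerator.num_processes,
)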

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between ecbe8b2 and bb65157.

📒 Files selected for processing (2)
  • src/axolotl/core/builders/causal.py (1 hunks)
  • src/axolotl/core/trainers/base.py (1 hunks)

Comment on lines +276 to +288
if (
    self.args.accelerator_config is not None
    and self.args.accelerator_config.split_batches
    and self.args.accelerator_config.dispatch_batches
):
    if self.args.sample_packing and self.args.pretraining:
        if not self.args.eval_sample_packing and not is_training:
            dataloader_params["batch_size"] *= self.accelerator.num_processes
        else:
            dataloader_params["batch_size"] = self.accelerator.num_processes
    elif not self.args.sample_packing and self.args.pretraining:
        dataloader_params["batch_size"] *= self.accelerator.num_processes


⚠️ Potential issue

Critical: Reintroducing batch_size alongside a BatchSampler causes DataLoader errors

This block can set or mutate dataloader_params["batch_size"] even when a BatchSampler is already configured (e.g., eval with sample packing returns a MultipackBatchSampler). PyTorch's DataLoader forbids passing both batch_sampler and batch_size. In the eval pretraining + sample_packing + eval_sample_packing=True path, we delete batch_size earlier and then re-add it here, leading to a runtime error and/or undefined behavior.

Guard against the presence of batch_sampler before touching batch_size:

         if (
             self.args.accelerator_config is not None
             and self.args.accelerator_config.split_batches
             and self.args.accelerator_config.dispatch_batches
         ):
-            if self.args.sample_packing and self.args.pretraining:
-                if not self.args.eval_sample_packing and not is_training:
-                    dataloader_params["batch_size"] *= self.accelerator.num_processes
-                else:
-                    dataloader_params["batch_size"] = self.accelerator.num_processes
-            elif not self.args.sample_packing and self.args.pretraining:
-                dataloader_params["batch_size"] *= self.accelerator.num_processes
+            # Avoid conflicting with BatchSampler: DataLoader forbids both batch_size and batch_sampler.
+            if "batch_sampler" not in dataloader_params:
+                if self.args.sample_packing and self.args.pretraining:
+                    if not self.args.eval_sample_packing and not is_training:
+                        dataloader_params["batch_size"] *= self.accelerator.num_processes
+                    else:
+                        dataloader_params["batch_size"] = self.accelerator.num_processes
+                elif not self.args.sample_packing and self.args.pretraining:
+                    dataloader_params["batch_size"] *= self.accelerator.num_processes

Follow-up: If you actually need different behavior for the multipack case under split/dispatch, the adjustment must be applied inside the MultipackBatchSampler construction (e.g., its batch_size or batch_max_len), not by reintroducing batch_size here.
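
For reference, a minimal standalone reproduction of the DataLoader constraint (plain PyTorch, not axolotl code):

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(8))
sampler = BatchSampler(SequentialSampler(ds), batch_size=2, drop_last=False)

try:
    # Passing both batch_sampler and a non-default batch_size is rejected.
    DataLoader(ds, batch_sampler=sampler, batch_size=2)
except ValueError as err:
    print(err)  # "batch_sampler option is mutually exclusive with batch_size, ..."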


@codecov bot commented Aug 18, 2025

Codecov Report

❌ Patch coverage is 14.28571% with 6 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/axolotl/core/trainers/base.py | 14.28% | 6 Missing ⚠️


):
    if self.args.sample_packing and self.args.pretraining:
        if not self.args.eval_sample_packing and not is_training:
            dataloader_params["batch_size"] *= self.accelerator.num_processes

Contributor

Should this only be scaling by the total dp size?

Collaborator (Author)

accelerator handles the actual dp sharding, so I don't think so? I'll test.
