handle batch size correctly when using split and dispatch batches #3076
Conversation
📝 Walkthrough

Updates pretraining handling in the causal collator condition and introduces new batch_size adjustment logic in AxolotlTrainer._get_dataloader for specific accelerator configurations during pretraining, with branching based on sample packing and eval/training mode.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/axolotl/core/builders/causal.py (1)
425-433: Truthiness change may alter pretraining collator selection for None/unspecified values

Switching from `is False` to `not self.cfg.pretraining_sample_concatenation` broadens the condition to treat any falsy value (e.g., `None`, `0`, `""`) as disabling concatenation. Previously, only an explicit `False` would take this path. This can change behavior for default/unspecified configs by returning `DataCollatorForSeq2Seq` where the code previously fell through to the multipack/None path.

If the intent was to only diverge on explicit `False`, keep the identity check:

```diff
-        if (
-            not self.cfg.pretraining_sample_concatenation
-            or self.cfg.micro_batch_size > 1
-        ):
+        if (
+            self.cfg.pretraining_sample_concatenation is False
+            or self.cfg.micro_batch_size > 1
+        ):
             return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
```

If the broader behavior is intended, consider documenting the new default semantics and adding tests for `pretraining_sample_concatenation` being `None` vs `False`.

Would you like me to add unit tests covering pretraining with `pretraining_sample_concatenation` set to True/False/None and micro_batch_size 1/>1?
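To make the truthiness difference concrete, here is a small standalone Python check, independent of axolotl's config objects:

```python
# The nitpick's point in isolation: `not value` and `value is False`
# disagree for unset/None-style config values.
for value in (True, False, None, 0, ""):
    broad = not value            # new check: any falsy value passes
    strict = value is False      # old check: only an explicit False passes
    print(f"{value!r:>6}  not value={broad}  value is False={strict}")

# None, 0, and "" satisfy `not value` but not `value is False`, which is
# exactly the behavioral drift flagged for default/unspecified configs.
```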
src/axolotl/core/trainers/base.py (2)

276-288: Double-check split vs. dispatch gating and set vs. multiply semantics

- Gating on both `split_batches` and `dispatch_batches` may be overly restrictive. If `split_batches=True` and `dispatch_batches=False`, do we still need the adjustment? Please confirm the intended accelerate semantics for both flags.
- In the pretraining + sample_packing path, the training/eval branches sometimes set the `batch_size` to `num_processes` instead of multiplying. That assumes per-device bs=1. If users set per-device bs>1 for these modes, this will shrink the effective global batch. Consider scaling multiplicatively when `batch_size > 1`; see the sketch after this list.

If desired, I can add a helper to normalize global vs per-process batch sizing across these modes and log the computed values for traceability.
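A minimal sketch of what multiplicative scaling could look like, assuming the `dataloader_params` dict and `num_processes` visible in the diff; the helper name is hypothetical and this is an illustration of the suggestion, not the PR's implementation:

```python
def _scale_batch_size_for_split_dispatch(
    dataloader_params: dict, num_processes: int
) -> None:
    """Hypothetical helper: scale an existing per-device batch_size up to a
    global batch size instead of overwriting it with num_processes."""
    # Multiplying preserves a user-configured per-device batch size > 1,
    # whereas plain assignment silently assumes per-device bs == 1.
    per_device = dataloader_params.get("batch_size", 1) or 1
    dataloader_params["batch_size"] = per_device * num_processes
```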
276-288: Add debug logging for batch-size rewrites to aid troubleshooting

Given the complexity of pretraining + sample packing + split/dispatch interactions, add a short LOG.debug with the before/after batch size, mode, and num_processes whenever this rewrite triggers.

I can add the debug statements in this block if you'd like.
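A sketch of the kind of trace the comment asks for; the `log_batch_size_rewrite` helper name is hypothetical, and the commented call assumes the names visible in the diff:

```python
import logging

LOG = logging.getLogger(__name__)

def log_batch_size_rewrite(
    before: int | None,
    after: int | None,
    *,
    is_training: bool,
    sample_packing: bool,
    pretraining: bool,
    num_processes: int,
) -> None:
    """Emit the before/after batch-size trace suggested in the review."""
    LOG.debug(
        "Rewrote dataloader batch_size %s -> %s "
        "(is_training=%s, sample_packing=%s, pretraining=%s, num_processes=%s)",
        before, after, is_training, sample_packing, pretraining, num_processes,
    )

# Example call from inside the adjustment block (names taken from the diff):
# log_batch_size_rewrite(before, dataloader_params["batch_size"],
#                        is_training=is_training,
#                        sample_packing=self.args.sample_packing,
#                        pretraining=self.args.pretraining,
#                        num_processes=self.accelerator.num_processes)
```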
📒 Files selected for processing (2)

- src/axolotl/core/builders/causal.py (1 hunks)
- src/axolotl/core/trainers/base.py (1 hunks)
🧰 Additional context used

🧬 Code Graph Analysis (1)

src/axolotl/core/builders/causal.py (1)
- src/axolotl/integrations/base.py (2): cfg (352-353), cfg (356-357)
```python
if (
    self.args.accelerator_config is not None
    and self.args.accelerator_config.split_batches
    and self.args.accelerator_config.dispatch_batches
):
    if self.args.sample_packing and self.args.pretraining:
        if not self.args.eval_sample_packing and not is_training:
            dataloader_params["batch_size"] *= self.accelerator.num_processes
        else:
            dataloader_params["batch_size"] = self.accelerator.num_processes
    elif not self.args.sample_packing and self.args.pretraining:
        dataloader_params["batch_size"] *= self.accelerator.num_processes
```
Critical: Reintroducing batch_size alongside a BatchSampler causes DataLoader errors
This block can set or mutate `dataloader_params["batch_size"]` even when a `batch_sampler` is already configured (e.g., eval with sample packing returns a `MultipackBatchSampler`). PyTorch's DataLoader forbids passing both `batch_sampler` and `batch_size`. In the eval pretraining + sample_packing + `eval_sample_packing=True` path, we delete `batch_size` earlier and then re-add it here, leading to a runtime error and/or undefined behavior.
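As a quick standalone check of the constraint cited here (plain PyTorch behavior, not axolotl code):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)

try:
    # Passing both batch_size and batch_sampler is rejected by DataLoader.
    DataLoader(dataset, batch_size=4, batch_sampler=batch_sampler)
except ValueError as err:
    print(err)  # "batch_sampler option is mutually exclusive with batch_size, ..."
```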
Guard against the presence of batch_sampler before touching batch_size:
```diff
         if (
             self.args.accelerator_config is not None
             and self.args.accelerator_config.split_batches
             and self.args.accelerator_config.dispatch_batches
         ):
-            if self.args.sample_packing and self.args.pretraining:
-                if not self.args.eval_sample_packing and not is_training:
-                    dataloader_params["batch_size"] *= self.accelerator.num_processes
-                else:
-                    dataloader_params["batch_size"] = self.accelerator.num_processes
-            elif not self.args.sample_packing and self.args.pretraining:
-                dataloader_params["batch_size"] *= self.accelerator.num_processes
+            # Avoid conflicting with BatchSampler: DataLoader forbids both batch_size and batch_sampler.
+            if "batch_sampler" not in dataloader_params:
+                if self.args.sample_packing and self.args.pretraining:
+                    if not self.args.eval_sample_packing and not is_training:
+                        dataloader_params["batch_size"] *= self.accelerator.num_processes
+                    else:
+                        dataloader_params["batch_size"] = self.accelerator.num_processes
+                elif not self.args.sample_packing and self.args.pretraining:
+                    dataloader_params["batch_size"] *= self.accelerator.num_processes
```

Follow-up: If you actually need different behavior for the multipack case under split/dispatch, the adjustment must be applied inside the MultipackBatchSampler construction (e.g., its batch_size or batch_max_len), not by reintroducing batch_size here.
Codecov Report

❌ Patch coverage is
```python
):
    if self.args.sample_packing and self.args.pretraining:
        if not self.args.eval_sample_packing and not is_training:
            dataloader_params["batch_size"] *= self.accelerator.num_processes
```
Should this only be scaling by the total dp size?
accelerator handles the actual dp sharding, so I don't think so? I'll test.
The batch size handling when using split batches isn't quite right. This PR sets the correct batch size based on how we handle batches when pretraining, scaling it properly for dispatch_batches/split_batches.
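For context on why the scaling is needed, here is a minimal sketch of how accelerate's split_batches/dispatch_batches options interact with the dataloader batch size, as I understand the documented behavior; the dataset and batch-size values are illustrative only, and it assumes a recent accelerate release that provides `DataLoaderConfiguration`:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration
from torch.utils.data import DataLoader, TensorDataset

# With split_batches=True, accelerate treats the DataLoader's batch_size as a
# *global* batch and splits it across processes; with dispatch_batches=True,
# rank 0 iterates the dataloader and dispatches slices to the other ranks.
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
)

per_device_micro_batch = 4  # illustrative value
dataset = TensorDataset(torch.arange(256))

# To end up with per_device_micro_batch samples on each process after the
# split, the raw dataloader batch_size has to be scaled by num_processes,
# which is the adjustment this PR makes for the pretraining paths.
loader = DataLoader(dataset, batch_size=per_device_micro_batch * accelerator.num_processes)
loader = accelerator.prepare(loader)

batch = next(iter(loader))[0]
print(batch.shape[0])  # per-process batch size == per_device_micro_batch
```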