From 7fd3d8abc436843d0a22cb5025fdfa791ff8eaa0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 16 Aug 2025 22:05:31 -0400 Subject: [PATCH 1/2] handle batch size correchtly when using split and dispatch batches --- src/axolotl/core/trainers/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 0f9f6e4c4d..86f125852e 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -272,6 +272,20 @@ def _get_dataloader( num_workers=self.args.dataloader_num_workers, rank=self.args.process_index, ) + + if ( + self.args.accelerator_config is not None + and self.args.accelerator_config.split_batches + and self.args.accelerator_config.dispatch_batches + ): + if self.args.sample_packing and self.args.pretraining: + if not self.args.eval_sample_packing and not is_training: + dataloader_params["batch_size"] *= self.accelerator.num_processes + else: + dataloader_params["batch_size"] = self.accelerator.num_processes + elif not self.args.sample_packing and self.args.pretraining: + dataloader_params["batch_size"] *= self.accelerator.num_processes + if self.args.sample_packing and ( (is_training and not self.args.pretraining) or (not is_training and self.args.eval_sample_packing is not False) From bb65157dcf337bf88828f5a6109978b4b33a219e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 17 Aug 2025 12:49:48 -0400 Subject: [PATCH 2/2] fix conditional for None values --- src/axolotl/core/builders/causal.py | 2 +- src/axolotl/exception_handling.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 src/axolotl/exception_handling.py diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e5bc217621..191ff388e4 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -424,7 +424,7 @@ def build_collator( ): if training_args.pretraining: if ( - self.cfg.pretraining_sample_concatenation is False + not self.cfg.pretraining_sample_concatenation or self.cfg.micro_batch_size > 1 ): return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) diff --git a/src/axolotl/exception_handling.py b/src/axolotl/exception_handling.py new file mode 100644 index 0000000000..e69de29bb2