diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e5bc217621..191ff388e4 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -424,7 +424,7 @@ def build_collator( ): if training_args.pretraining: if ( - self.cfg.pretraining_sample_concatenation is False + not self.cfg.pretraining_sample_concatenation or self.cfg.micro_batch_size > 1 ): return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 0f9f6e4c4d..86f125852e 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -272,6 +272,20 @@ def _get_dataloader( num_workers=self.args.dataloader_num_workers, rank=self.args.process_index, ) + + if ( + self.args.accelerator_config is not None + and self.args.accelerator_config.split_batches + and self.args.accelerator_config.dispatch_batches + ): + if self.args.sample_packing and self.args.pretraining: + if not self.args.eval_sample_packing and not is_training: + dataloader_params["batch_size"] *= self.accelerator.num_processes + else: + dataloader_params["batch_size"] = self.accelerator.num_processes + elif not self.args.sample_packing and self.args.pretraining: + dataloader_params["batch_size"] *= self.accelerator.num_processes + if self.args.sample_packing and ( (is_training and not self.args.pretraining) or (not is_training and self.args.eval_sample_packing is not False) diff --git a/src/axolotl/exception_handling.py b/src/axolotl/exception_handling.py new file mode 100644 index 0000000000..e69de29bb2