
Conversation


@jariskueken jariskueken commented Nov 12, 2025

Enhanced Pre-Training Pipeline

The main changes in this PR are:

  • refactored the pre-training pipeline into a separate Trainer class
  • added support for DDP
  • slightly adapted logging during training
  • restructured the project a bit: moved all relevant training scripts into scripts/ and added a simpler, slimmer training config pretrain_classification_new.py
  • added Hydra support for better config handling

This is still WIP. There are open issues, specifically with DDP, caused by the current structure of the DataLoader/PriorDataSets, which makes it really difficult to handle DDP properly.

@AlexanderPfefferle the problem, as I see it, is that the datasets themselves handle batching. This leads to the following issue: when training with DDP, one would normally use something like a DistributedSampler to split a global batch of size b into m smaller batches of size b/m, where m is the number of GPUs we are using. However, because our datasets already return batched elements, the internal DataLoader inside the data pipeline effectively runs with batch_size=None, so we cannot do any batch splitting.
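A minimal sketch of the situation described above, with hypothetical stand-in names (the real PriorDataset classes differ): when the dataset already yields whole batches, the inner DataLoader must be created with batch_size=None, which disables its own batching and leaves no place for a sampler-based per-GPU split.

```python
# Sketch of the current situation (all names here are hypothetical stand-ins):
# the dataset yields full batches, so the DataLoader runs with batch_size=None
# and a DistributedSampler has no per-sample indices left to shard.
import torch
from torch.utils.data import IterableDataset, DataLoader

class PreBatchedPrior(IterableDataset):
    """Stand-in for a prior dataset that yields whole batches of size b."""
    def __init__(self, batch_size: int, num_batches: int):
        self.batch_size = batch_size
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            # One full batch at a time, already collated.
            yield torch.randn(self.batch_size, 3)

# batch_size=None disables the DataLoader's own batching/collation, so the
# usual "split b into b/m per GPU" step cannot happen at this level.
loader = DataLoader(PreBatchedPrior(batch_size=8, num_batches=2), batch_size=None)
first = next(iter(loader))
assert first.shape == (8, 3)  # each "sample" is already a full batch
```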
As far as I can see, we could change this in two ways:

  • either the Dataset object (pre-loaded or prior) returns only one dataset at a time. This, however, makes loading much more expensive, as our stored data / the data from the prior usually already comes in a batched format. In that setting we would leave batching entirely to the DataLoader inside the Trainer class.
  • or we somehow incorporate distributed loading into the PriorDataLoaders themselves, though I think this would really complicate things, as we would need to handle it outside of the Trainer, which I would rather avoid.
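The first option would roughly follow the standard DDP data path; a hedged sketch, assuming a map-style dataset that returns one toy sample per index (class and function names here are hypothetical, not the project's actual API):

```python
# Sketch of option one: the dataset returns *single* samples, and a
# DistributedSampler shards the indices so each of the m ranks sees
# b/m samples per step. All names below are hypothetical.
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler

class SingleSampleDataset(Dataset):
    """Toy stand-in: returns one sample per index, not a pre-batched chunk."""
    def __init__(self, n: int):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def make_loader(dataset, global_batch_size, rank, world_size):
    # Passing num_replicas/rank explicitly avoids needing an initialized
    # process group for this illustration; shuffle=False for determinism.
    sampler = DistributedSampler(dataset, num_replicas=world_size,
                                 rank=rank, shuffle=False)
    # Each rank gets global_batch_size // world_size samples per step.
    return DataLoader(dataset,
                      batch_size=global_batch_size // world_size,
                      sampler=sampler)
```

The cost, as noted above, is that loading becomes per-sample even though the stored/prior data already arrives batched.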

At least for on-the-fly generation this is not really a problem: one could simply change the seed for each machine inside the PriorDataset, so each machine would draw different datasets from the prior anyway, and the effective batch size would increase to b*m.
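A minimal sketch of the per-machine seeding idea, with a hypothetical stand-in class (not the real PriorDataset): offsetting a base seed by the rank decorrelates the draws, so every GPU contributes b distinct prior samples for an effective batch size of b*m.

```python
# Hypothetical stand-in for the real PriorDataset: each rank offsets the
# base seed by its rank id, so the m machines draw different datasets from
# the prior and the effective batch size becomes b * m.
import random

class SeededPrior:
    def __init__(self, base_seed: int, rank: int):
        # rank-offset seed -> independent streams per machine
        self.rng = random.Random(base_seed + rank)

    def draw_batch(self, batch_size: int):
        return [self.rng.random() for _ in range(batch_size)]

# Two ranks with the same base seed produce different batches:
r0 = SeededPrior(base_seed=42, rank=0).draw_batch(4)
r1 = SeededPrior(base_seed=42, rank=1).draw_batch(4)
assert r0 != r1  # each machine sees different prior samples
```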

I think we should find a solution for this before merging, and also adapt the structure of the PriorDataLoader to be consistent with the rest of the project.

Open TODOs

  • add Hydra to the dependencies
  • better handling of artifacts (storing and loading; right now I just copied this from the old training code)
  • maybe we already want more flexibility in loading different model architectures, but this PR probably shouldn't become too large either
  • better DDP sampling for datasets
  • documentation for the Trainer class

