Run train loop inside the train loop #7309
daviddavini started this conversation in General · Replies: 0 comments
TL;DR: During `on_train_epoch_start`, I want to get the model's output on ALL the training data, as part of some pre-training calculations. I'm asking what the Lightning-friendly way to do that is.
In my project, every 10 epochs I select a subset of the full training data and train only on that subset. As part of deciding which subset to use, I compute the model's output on every datapoint in the train dataset.
My question is: what's the best way to do this in PyTorch Lightning? Currently I have a callback with an `on_train_epoch_start` hook. In this hook, the callback builds its own dataloader from `trainer.datamodule.train_dataloader()` and manually iterates over it, computing the model outputs.
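Concretely, my callback looks roughly like this (the `(input, target)` batch structure is just for illustration, and the actual selection step is elided):

```python
import torch
import pytorch_lightning as pl


class SubsetSelectionCallback(pl.Callback):
    """Every 10 epochs, score the full train set to pick a subset to train on."""

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch % 10 != 0:
            return
        # Build a fresh dataloader over the FULL training data.
        full_loader = trainer.datamodule.train_dataloader()
        pl_module.eval()
        outputs = []
        with torch.no_grad():
            for batch in full_loader:
                x, _ = batch  # assumes (input, target) batches
                # NOTE: `x` comes back on CPU here, which is exactly the
                # device mismatch described below when the model is on GPU.
                outputs.append(pl_module(x))
        pl_module.train()
        # ... use `outputs` to decide which subset to train on ...
```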
This runs me into problems with PyTorch Lightning. For instance, when training on the GPU I get a device error: since my callback iterates its own dataloader rather than the trainer's, the batches are never moved to the GPU. But I can't simply reuse the trainer's dataloader either, because after my callback selects its subset, it replaces the trainer's dataloader with one over just that subset instead of the full dataset.
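The obvious workaround I can see is to move each batch onto the model's device by hand and rebuild the reduced loader with `torch.utils.data.Subset`, roughly like this (the helper names here are placeholders, and I'm not sure this is the idiomatic approach):

```python
import torch
from torch.utils.data import DataLoader, Subset


def score_full_train_set(trainer, pl_module):
    """Run the model over the full train set, moving batches to the model's device."""
    full_loader = trainer.datamodule.train_dataloader()
    pl_module.eval()
    outputs = []
    with torch.no_grad():
        for batch in full_loader:
            x, _ = batch  # assumes (input, target) batches
            x = x.to(pl_module.device)  # avoids the CPU/GPU mismatch
            outputs.append(pl_module(x).cpu())
    pl_module.train()
    return torch.cat(outputs)


def make_subset_loader(full_dataset, indices, batch_size=32):
    """Wrap the chosen indices in a Subset so training sees only that slice."""
    return DataLoader(Subset(full_dataset, indices), batch_size=batch_size, shuffle=True)
```

But manually managing devices feels like it's working against Lightning rather than with it, which is why I'm asking here.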
I guess I have two main questions: