Use lightning to run a single X * batch-sized fit step(s)? #13134
Unanswered
aniongithub asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
Currently, calling `trainer.fit` runs the trainer until the specified training conditions are met or an early exit is indicated. I'd like to use pytorch-lightning as part of a larger, graph-like training and inference data flow. This means a blocking `fit` call that lasts an epoch or more wouldn't really work for my use case, because my orchestrator might run some other code both before and after the `LightningModule`/`LightningDataModule` (perhaps producing/consuming the inputs and outputs of the module).

Is there a way to run this such that each call to pytorch-lightning runs one step of the under-the-hood steps shown here? This would let me orchestrate the larger dataflow while invoking Lightning to deal with distributed training of each batch.
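
To make the unit of work concrete: in plain-PyTorch terms, each call should roughly do the following (toy model and data; `X` stands in for `accumulate_grad_batches`):

```python
import torch

model = torch.nn.Linear(32, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
X = 4  # stand-in for accumulate_grad_batches


def run_one_accumulated_step(batches):
    """Accumulate gradients over X batches, then take one optimizer step."""
    optimizer.zero_grad()
    for x, y in batches:
        loss = torch.nn.functional.mse_loss(model(x), y) / X
        loss.backward()
    optimizer.step()


# Example invocation with random data (shapes are arbitrary).
batches = [(torch.randn(8, 32), torch.randn(8, 1)) for _ in range(X)]
run_one_accumulated_step(batches)
```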
Is using `fast_dev_run` the correct way of doing this? Can I simply keep calling `trainer.fit(...)` in a loop on a trainer created with `fast_dev_run=X`, where `X = trainer.accumulate_grad_batches`, whenever my orchestrator reaches a pytorch-lightning node that needs to run one training step inside my larger workflow (roughly the pattern sketched below)? If not, what other options are available to me to do this?
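
For concreteness, here is a minimal, self-contained sketch of the calling pattern I'm asking about. The toy module and data are placeholders for my real workload, and the fixed loop stands in for my orchestrator; I'm not sure this is a supported usage, which is exactly my question:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def make_loader():
    x, y = torch.randn(256, 32), torch.randn(256, 1)
    return DataLoader(TensorDataset(x, y), batch_size=8)


X = 4  # stand-in for accumulate_grad_batches
model = ToyModule()
# Assumption: fast_dev_run and accumulate_grad_batches can be combined like this.
trainer = pl.Trainer(fast_dev_run=X, accumulate_grad_batches=X,
                     enable_progress_bar=False)

# Each time my orchestrator reaches the "training" node, run one
# X-batch-sized chunk of training by calling fit() again.
for _ in range(10):  # placeholder for visits by the orchestrator
    trainer.fit(model, train_dataloaders=make_loader())
```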
Thanks in advance!
Edit: Trying it out, it appears that `fast_dev_run` calls `setup` for every call to `fit`. I was able to prevent re-initialization of my datasets using a flag (roughly as sketched below), but regardless of that, I never see the loss or accuracy change over the course of many iterations. So it appears `fast_dev_run` isn't what I'm looking for.