How to pass additional parameters to training_step/validation_step? #5350
laughingrice asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule (unanswered, 0 replies)
I have a loss function that depends on an additional tensor, A.
Something similar to:
I tried adding A to the class as a hyperparameter (not great), but training then crashes at checkpointing with an error saying that not all tensors are on the same device ("found at least two devices, cuda:2 and cpu!").
It also prevents me from running on multiple GPUs, because A is a sparse matrix and won't pickle.
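One common pattern for a fixed tensor that the loss needs (a sketch, not a confirmed answer from this thread) is to register it as a buffer instead of a hyperparameter. Buffers are moved together with the module by .to()/.cuda(), and a LightningModule is an nn.Module, so Lightning's device placement covers them. The sketch uses plain PyTorch so it runs without Lightning; the class name ModelWithExtraTensor, the compute_loss method, the penalty term, and all shapes are hypothetical. Passing persistent=False (available since PyTorch 1.6) keeps A out of state_dict and therefore out of checkpoints, on the assumption that A can be rebuilt when the model is loaded. Whether a sparse buffer works with every multi-GPU strategy is not verified here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithExtraTensor(nn.Module):
    # Hypothetical model. In Lightning this would subclass LightningModule
    # (itself an nn.Module); the register_buffer call is the same there.

    def __init__(self, A: torch.Tensor, features: int = 4):
        super().__init__()
        self.layer = nn.Linear(features, features)
        # A buffer follows the module through .to()/.cuda(), unlike a plain
        # attribute or a hyperparameter. persistent=False leaves it out of
        # state_dict (assumption: A can be rebuilt when loading).
        self.register_buffer("A", A, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def compute_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        pred = self(x)
        # Hypothetical A-dependent term: sparse (m, features) @ dense (features, batch).
        penalty = torch.sparse.mm(self.A, pred.t()).pow(2).mean()
        return F.mse_loss(pred, y) + penalty


# Usage: a small sparse A with shape (2, 4).
indices = torch.tensor([[0, 1, 1], [0, 2, 3]])
values = torch.tensor([1.0, -1.0, 0.5])
A = torch.sparse_coo_tensor(indices, values, size=(2, 4)).coalesce()

model = ModelWithExtraTensor(A).to("cpu")  # .to("cuda") would move A as well
x, y = torch.randn(8, 4), torch.randn(8, 4)
loss = model.compute_loss(x, y)
loss.backward()
```

In a LightningModule, training_step and validation_step would simply call self.compute_loss(x, y). If A is also passed to __init__, recent Lightning versions accept self.save_hyperparameters(ignore=["A"]) so that A is not stored as a hyperparameter.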
Is there a way to pass additional parameters to training_step/validation_step? Maybe override torch.utils.data.DataLoader to pass a triplet instead of a pair? If so, how would I make sure everything ends up on the right device for multi-GPU training?
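For the triplet idea, a sketch that needs no DataLoader subclass, assuming the extra data differs per sample: have the Dataset return three items. PyTorch's default collate function stacks each position separately, and Lightning's default batch transfer moves tensors nested in lists, tuples, and dicts to the module's device, including under multi-GPU strategies where each process has its own device. TripletDataset and the tensor shapes below are hypothetical. This only suits per-sample data; a single tensor shared by every batch (like a fixed matrix A) does not belong in each batch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TripletDataset(Dataset):
    # Hypothetical dataset: each item is (input, target, extra per-sample tensor).

    def __init__(self, x: torch.Tensor, y: torch.Tensor, extra: torch.Tensor):
        assert len(x) == len(y) == len(extra)
        self.x, self.y, self.extra = x, y, extra

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx], self.extra[idx]


dataset = TripletDataset(torch.randn(10, 4), torch.randn(10, 4), torch.randn(10, 3))
# The default collate_fn stacks each of the three positions into its own tensor.
loader = DataLoader(dataset, batch_size=4)
batches = list(loader)

# Inside a LightningModule, training_step/validation_step would unpack the same way:
#     def training_step(self, batch, batch_idx):
#         x, y, extra = batch
x, y, extra = batches[0]
```

With 10 samples and batch_size=4, the loader yields batches of 4, 4, and 2 samples.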