How to put all but some vars to GPU #7725
By default, in Lightning, everything a dataset returns is collated by the data loader and shipped to the same device. I would like to move all but some of these variables to the GPU, with something like:

    def ship_batch(self, batch):
        batch[0] = batch[0].to(self.device)
        # ...
        batch[2] = batch[2].cpu()  # just for illustration here
        return batch
Dear @Haydnspass,
You have several ways to do this:
Create a custom data/batch object and implement its .to method so that it moves only what is required.
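A minimal sketch of this first option, assuming each sample is an (x, y, meta) tuple where meta (a hypothetical name) should stay on the CPU. PartialBatch and collate_partial are illustrative names, not Lightning API. Lightning's default batch transfer supports anything that implements .to(...), so the custom object itself decides what actually moves:

```python
from dataclasses import dataclass
from typing import Any, List, Tuple

import torch


@dataclass
class PartialBatch:
    """Hypothetical batch container: .to() moves x and y, but never meta."""

    x: torch.Tensor
    y: torch.Tensor
    meta: torch.Tensor  # stays on the CPU

    def to(self, device: torch.device, **kwargs: Any) -> "PartialBatch":
        # Build a new batch instead of mutating this one in place.
        return PartialBatch(
            x=self.x.to(device, **kwargs),
            y=self.y.to(device, **kwargs),
            meta=self.meta,  # deliberately not moved
        )


def collate_partial(
    samples: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
) -> PartialBatch:
    """Hypothetical collate_fn: stacks (x, y, meta) samples into one PartialBatch."""
    xs, ys, metas = zip(*samples)
    return PartialBatch(torch.stack(xs), torch.stack(ys), torch.stack(metas))
```

You would pass it as DataLoader(dataset, batch_size=32, collate_fn=collate_partial); training_step then receives a PartialBatch whose x and y are on the model's device while meta is still on the CPU.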
Simpler: override the LightningModule.transfer_batch_to_device hook and add your own logic to move only x and y to the right device.
Best,
T.C