Replies: 2 comments
-
This is because Lightning can't figure out how to transfer a custom batch to the device. It should work if you define a to(device) method on the class:

    class CustomBatch:
        ...
        def to(self, device):
            self.x = self.x.to(device)
            self.y = self.y.to(device)
            ...
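For anyone who wants to try this outside a Trainer, here is a minimal sketch of the same idea using plain PyTorch on CPU. The two-tensor layout (x, y), the toy TensorDataset, and the collate_fn body are assumptions made for illustration. The `return self` line is also an addition not shown in the snippet above; it is there as a precaution for callers that use the return value of .to().

```python
from dataclasses import dataclass

import torch
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class CustomBatch:
    # Hypothetical layout: one feature tensor and one label tensor.
    x: torch.Tensor
    y: torch.Tensor

    def to(self, device):
        # Move every tensor field to the target device.
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        # Assumption: return self so callers that use the result of
        # .to() still receive the batch object.
        return self


def collate_fn(samples):
    # samples is a list of (x, y) pairs produced by the dataset.
    xs, ys = zip(*samples)
    return CustomBatch(x=torch.stack(xs), y=torch.stack(ys))


# Toy data: 8 samples with 3 features each and a binary label.
dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

batch = next(iter(loader))
moved = batch.to(torch.device("cpu"))
```

If the batch class cannot expose a .to() method, overriding the LightningModule hook transfer_batch_to_device is the other documented route, though it is not what this reply suggests.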
-
Thanks! That is exactly what I was looking for :)
-
Hi there!
In my collate_fn function, I am trying to return my own custom batch class instead of a list of tensors. The problem is that the Trainer does not send the CustomBatch tensors to the correct device. If I return a list of tensors instead, it works fine.
Do I need to add any special method to CustomBatch so the Trainer knows how to manage its content?
Best,
Arturo