Extract_batch_size and Custom Batch Objects That Implement __len__ and __iter__. #13170
Unanswered
jonathanking asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
According to the way `extract_batch_size` works, I am unable to use a custom batch object with PyTorch Lightning without explicitly passing the length of the batch to every call to `self.log()` (and there are many such calls). Instead, the following error message appears:

`pytorch_lightning.utilities.exceptions.MisconfigurationException: We could not infer the batch_size from the batch. Either simplify its structure or provide the batch_size as self.log(..., batch_size=batch_size)`

My custom batch object is an iterable and supports the `len()` and `iter()` operations. The batch object yields custom data objects (not tensors) when iterated over, so PyTorch Lightning does not work in this case.

I believe this is something that should be supported by PyTorch Lightning, especially when custom batch objects implement `__iter__` or `__len__`. Do you agree?