[Q] Can someone explain the logic behind tbptt_split_batch's splitting dimension? #10086
garrett361 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
I am confused about the code in `tbptt_split_batch` referenced at the bottom of this post, specifically why it apparently assumes that each `x` in `for x in batch` has its time dimension at index `1`. I would have thought that `tbptt_split_batch` would be designed with torch's `RNN`, `GRU`, or `LSTM` in mind (with `batch_first=True`), in which case I would expect `batch` to be of shape `(batch_size, sequence_length, input_size)`. But if that were the case, then `time_dims = [len(x[0]) for x in batch]` would mistakenly take `input_size` to be the size of the time dimension, and the rest of `tbptt_split_batch` would split along that dimension rather than the `sequence_length` dimension, no?

I came across this when attempting to use `tbptt_split_batch` for a custom dataset. I needed to override it so that it splits over the dimension I expected, as outlined above, and that seems to work correctly. I feel like I'm badly misunderstanding something.
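The override described above might look roughly like the following minimal sketch. It assumes `batch` is a single `(batch_size, sequence_length, input_size)` batch and splits it along axis 1 (time). Nested Python lists stand in for a tensor so the sketch runs without torch; `make_batch`, `shape`, and `split_along_time` are hypothetical helpers for illustration, not Lightning API.

```python
# Hypothetical override logic: split one (b, t, d) batch along the
# time axis (axis 1). Nested lists stand in for a tensor.


def make_batch(b, t, d):
    """Build a nested list shaped (b, t, d); every value at time step j equals j."""
    return [[[float(j)] * d for j in range(t)] for _ in range(b)]


def shape(x):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def split_along_time(batch, split_size):
    """Split a (b, t, d) batch into chunks of at most split_size time steps.

    Mirrors batch[:, t:t + split_size] on a tensor.
    """
    seq_len = len(batch[0])  # each sample is (t, d), so len() gives t
    return [
        [sample[t:t + split_size] for sample in batch]
        for t in range(0, seq_len, split_size)
    ]


batch = make_batch(b=4, t=10, d=3)
chunks = split_along_time(batch, split_size=4)
print([shape(c) for c in chunks])  # [(4, 4, 3), (4, 4, 3), (4, 2, 3)]
```

With an actual tensor, `batch.split(split_size, dim=1)` or the slice `batch[:, t : t + split_size]` should produce the same chunks, with the last one shorter when `sequence_length` is not a multiple of `split_size`.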
https://github.com/PyTorchLightning/pytorch-lightning/blob/c9bc10ce8473a2249ffa4e00972c0c3c1d2641c4/pytorch_lightning/core/lightning.py#L1720-L1739
Edit, added for clarity: in the above I assume that `batch` is a `(b, t, d)`-shaped input tensor, where the three numbers are the batch size, sequence length, and input dimension, respectively. Then it would seem that `len(x[0])` gives `d`, rather than the expected `t`, for each `x`. I would have thought that line should instead measure `t`, which is what I do in my own `tbptt_split_batch` methods. I guess different assumptions are being made about the shape of the input tensor?
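The two possible shape assumptions can be compared directly with a small sketch. It evaluates `[len(x[0]) for x in batch]` under (A) the layout assumed in the question, a single `(b, t, d)` array, and (B) a tuple such as `(inputs, targets)` whose elements are each shaped `(b, t, ...)`. Layout B is only an assumed alternative under which the expression does measure `t`; this sketch does not establish which layout Lightning intends. Nested lists stand in for tensors, and `zeros` is a hypothetical helper.

```python
# Compare what `[len(x[0]) for x in batch]` measures under two layouts.
# Nested lists stand in for tensors; len() on a tensor likewise returns
# the size of its first dimension.


def zeros(*shape):
    """Build a nested list of zeros with the given shape."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(*shape[1:]) for _ in range(shape[0])]


b, t, d = 4, 10, 3

# Layout A (the question's assumption): batch is one (b, t, d) array.
# Iterating yields b samples of shape (t, d), so x[0] has shape (d,).
batch_a = zeros(b, t, d)
time_dims_a = [len(x[0]) for x in batch_a]

# Layout B (assumed alternative): batch is a tuple of arrays, each (b, t, ...).
# Iterating yields whole arrays, so x[0] is one sample of shape (t, ...).
batch_b = (zeros(b, t, d), zeros(b, t, 1))
time_dims_b = [len(x[0]) for x in batch_b]

print(time_dims_a)  # [3, 3, 3, 3]  -> d, not t
print(time_dims_b)  # [10, 10]      -> t
```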