Closed
Labels: docs (Documentation related)
Description
📚 Documentation
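The TBPTT example at
https://lightning.ai/docs/pytorch/stable/common/tbptt.html
contains a couple of odd lines: `self.optimizer.step()` and `self.optimizer.zero_grad()`. A `LightningModule` has no `self.optimizer` attribute; in manual optimization the optimizer is obtained with `self.optimizers()`.
Also, shouldn't it use `self.manual_backward(loss)` instead of `self.backward(loss)`?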
Also, another documentation page states that calling `optimizer.zero_grad()` right before `self.manual_backward(loss)` is good practice, yet the TBPTT example doesn't follow it.
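The placement of `zero_grad()` matters because PyTorch adds each new gradient into `.grad` instead of overwriting it. Here is a minimal plain-PyTorch sketch, with no Lightning involved. The toy loss `3 * w` and the function name `grad_seen_by_step` are illustrative assumptions. It shows the gradient used by the second optimizer step with and without clearing before `backward()`.

```python
import torch


def grad_seen_by_step(clear_before_backward: bool) -> float:
    """Run two optimization steps on loss = 3 * w and return w.grad at the second step."""
    w = torch.tensor(2.0, requires_grad=True)
    optimizer = torch.optim.SGD([w], lr=0.1)
    grad = 0.0
    for _ in range(2):
        loss = 3 * w  # d(loss)/dw is always 3
        if clear_before_backward:
            optimizer.zero_grad()  # drop gradients left over from the previous step
        loss.backward()  # adds 3 into w.grad
        grad = w.grad.item()
        optimizer.step()
    return grad


print(grad_seen_by_step(clear_before_backward=True))   # 3.0
print(grad_seen_by_step(clear_before_backward=False))  # 6.0: the stale gradient was accumulated
```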
It would make more sense to write:
```python
import torch

# Assumes step 1 of the docs: `self.automatic_optimization = False` in `__init__`

# 2. Remove the `hiddens` argument
def training_step(self, batch, batch_idx):
    # 3. Split the batch in chunks along the time dimension
    # (`split_batch` is assumed to be a helper defined elsewhere)
    split_batches = split_batch(batch, self.truncated_bptt_steps)

    batch_size = 10
    hidden_dim = 20
    hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

    # get optimizer
    optimizer = self.optimizers()

    # iterate over the chunks themselves without shadowing the `split_batch` helper
    for chunk in split_batches:
        # 4. Perform the optimization in a loop
        loss, hiddens = self.my_rnn(chunk, hiddens)
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()

        # 5. "Truncate"
        hiddens = hiddens.detach()

    # 6. Remove the return of `hiddens`
    # Returning loss in manual optimization is not needed
    return None
```
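For reference, here is a self-contained plain-PyTorch sketch of the same TBPTT loop outside Lightning. It makes several assumptions. `split_batch` is a hypothetical helper built on `torch.split` along the time dimension of a batch-first `(batch, time, features)` tensor. An `nn.GRU` plus a linear head stands in for `self.my_rnn`. `tbptt_epoch` is an illustrative name. Inside a `LightningModule`, `loss.backward()` would become `self.manual_backward(loss)`, and the optimizer would come from `self.optimizers()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_batch(batch: torch.Tensor, steps: int) -> tuple[torch.Tensor, ...]:
    """Hypothetical helper: split a (batch, time, features) tensor into chunks of
    `steps` time steps; the last chunk is shorter if time is not divisible by steps."""
    return torch.split(batch, steps, dim=1)


def tbptt_epoch(
    rnn: nn.GRU,
    head: nn.Linear,
    x: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    steps: int,
) -> tuple[list[float], torch.Tensor]:
    """One truncated-BPTT pass over a single batch; assumes `rnn` uses batch_first=True."""
    batch_size = x.size(0)
    # h_0 is (num_layers, batch, hidden) even when batch_first=True
    hiddens = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
    losses = []
    for x_chunk, y_chunk in zip(split_batch(x, steps), split_batch(y, steps)):
        out, hiddens = rnn(x_chunk, hiddens)
        loss = F.mse_loss(head(out), y_chunk)
        optimizer.zero_grad()       # clear gradients from the previous chunk
        loss.backward()             # in a LightningModule: self.manual_backward(loss)
        optimizer.step()
        hiddens = hiddens.detach()  # "truncate": the next chunk won't backprop into this one
        losses.append(loss.item())
    return losses, hiddens


if __name__ == "__main__":
    torch.manual_seed(0)
    rnn = nn.GRU(input_size=4, hidden_size=20, batch_first=True)
    head = nn.Linear(20, 1)
    opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)
    x = torch.randn(10, 35, 4)  # batch_size=10, 35 time steps, 4 features
    y = torch.randn(10, 35, 1)
    losses, h = tbptt_epoch(rnn, head, x, y, opt, steps=10)
    print(len(losses), tuple(h.shape), h.requires_grad)  # 4 (1, 10, 20) False
```

With 35 time steps and `steps=10`, the sequence is split into chunks of 10, 10, 10 and 5, so four optimizer steps are taken. The hidden state carries forward across chunks, but gradients stop at each chunk boundary.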
cc @Borda