Documentation Example of Truncated BPTT is not working; self.optimizer.step() makes no sense #20517

@simon-bachhuber

Description

📚 Documentation

The TBPTT example at
https://lightning.ai/docs/pytorch/stable/common/tbptt.html
contains a couple of odd lines: it calls self.optimizer.step() and self.optimizer.zero_grad(), but LightningModule has no self.optimizer attribute; in manual optimization the optimizer has to be fetched via self.optimizers().
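
Concretely, the lines in question read roughly like this in the linked example (reconstructed from the points above; the surrounding code may differ):

    self.optimizer.step()       # `self.optimizer` is never defined on the module
    self.optimizer.zero_grad()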

Also, shouldn't one use self.manual_backward instead of self.backward?

Also, in another documentation page you state that calling optimizer.zero_grad() right before the backward call is preferred and good practice, yet you don't do it here.
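
For reference, Lightning's documented manual optimization pattern looks roughly like this (a minimal sketch; compute_loss is a hypothetical placeholder for the actual forward pass):

    # assumes `self.automatic_optimization = False` was set in `__init__`
    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()                  # zero gradients right before backward
        loss = self.compute_loss(batch)  # hypothetical placeholder
        self.manual_backward(loss)       # manual_backward, not self.backward
        opt.step()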

It would make more sense to write:

    # (Assumes step 1: `self.automatic_optimization = False` was set in `__init__`)

    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):

        # 3. Split the batch in chunks along the time dimension
        split_batches = split_batch(batch, self.truncated_bptt_steps)

        batch_size = 10
        hidden_dim = 20
        hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

        # get the optimizer (manual optimization)
        optimizer = self.optimizers()

        for split_batch in split_batches:
            # 4. Perform the optimization in a loop over the chunks
            loss, hiddens = self.my_rnn(split_batch, hiddens)
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

            # 5. "Truncate": detach so gradients do not flow across chunks
            hiddens = hiddens.detach()

        # 6. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None
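
Since the snippet never defines split_batch, a plausible helper (hypothetical; it assumes the time axis is dimension 0 of a tensor batch) could be:

    import torch

    def split_batch(batch, tbptt_steps):
        # torch.split returns chunks of length `tbptt_steps` along dim 0 (time)
        return torch.split(batch, tbptt_steps, dim=0)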

cc @Borda
