
Commit efae604

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1fc2cfa commit efae604

1 file changed: +7 -7 lines changed


docs/source-pytorch/common/lightning_module.rst

Lines changed: 7 additions & 7 deletions
@@ -1074,30 +1074,30 @@ recurrent network trajectories."

    def __init__(self):
        super().__init__()
-
+
        # 1. Switch to manual optimization
        self.automatic_optimization = False
-
+
        self.truncated_bptt_steps = 10
        self.my_rnn = ...
-
+
    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):
-
+
        # 3. Split the batch in chunks along the time dimension
        split_batches = split_batch(batch, self.truncated_bptt_steps)
-
+
        hiddens = ...  # 3. Choose the initial hidden state
        for split_batch in range(split_batches):
            # 4. Perform the optimization in a loop
            loss, hiddens = self.my_rnn(split_batch, hiddens)
            self.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
-
+
            # 5. "Truncate"
            hiddens = hiddens.detach()
-
+
        # 6. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None
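The hunk above only strips trailing whitespace from the docs' step-by-step migration from the removed `truncated_bptt_steps` automation to manual optimization; the snippet itself leaves `...` placeholders for the model, the chunking helper, and the optimizer. For context, here is a minimal runnable sketch of the same pattern, not part of this commit: the model, the chunking, and the loss are hypothetical choices (an nn.LSTM, torch.split along the time dimension, MSE), the optimizer is fetched via self.optimizers(), self.manual_backward is used for the backward pass, and the loop iterates over the chunks directly.

import torch
import torch.nn as nn
import lightning.pytorch as pl


class TBPTTModule(pl.LightningModule):
    def __init__(self, input_size=8, hidden_size=32):
        super().__init__()

        # 1. Switch to manual optimization
        self.automatic_optimization = False

        self.truncated_bptt_steps = 10
        self.my_rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, input_size)

    # 2. No `hiddens` argument in the signature
    def training_step(self, batch, batch_idx):
        x, y = batch  # hypothetical shapes: (batch, time, features)
        opt = self.optimizers()

        # 3. Split the batch into chunks along the time dimension
        x_chunks = torch.split(x, self.truncated_bptt_steps, dim=1)
        y_chunks = torch.split(y, self.truncated_bptt_steps, dim=1)

        hiddens = None  # 3. Initial hidden state (nn.LSTM defaults to zeros)
        for x_t, y_t in zip(x_chunks, y_chunks):
            # 4. Perform the optimization in a loop
            out, hiddens = self.my_rnn(x_t, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y_t)
            self.manual_backward(loss)
            opt.step()
            opt.zero_grad()

            # 5. "Truncate": detach the hidden state between chunks
            hiddens = tuple(h.detach() for h in hiddens)

        # 6. Nothing is returned in manual optimization

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Detaching the hidden state between chunks is what truncates the backward graph: gradients flow within each window of truncated_bptt_steps time steps but not across windows.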
