File tree Expand file tree Collapse file tree 1 file changed +7
-7
lines changed
docs/source-pytorch/common Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -1074,30 +1074,30 @@ recurrent network trajectories."
10741074
10751075 def __init__ (self ):
10761076 super ().__init__ ()
1077-
1077+
10781078 # 1. Switch to manual optimization
10791079 self .automatic_optimization = False
1080-
1080+
10811081 self .truncated_bptt_steps = 10
10821082 self .my_rnn = ...
1083-
1083+
10841084 # 2. Remove the `hiddens` argument
10851085 def training_step (self , batch , batch_idx ):
1086-
1086+
10871087 # 3. Split the batch in chunks along the time dimension
10881088 split_batches = split_batch(batch, self .truncated_bptt_steps)
1089-
1089+
10901090 hiddens = ... # 3. Choose the initial hidden state
10911091 for split_batch in range (split_batches):
10921092 # 4. Perform the optimization in a loop
10931093 loss, hiddens = self .my_rnn(split_batch, hiddens)
10941094 self .backward(loss)
10951095 optimizer.step()
10961096 optimizer.zero_grad()
1097-
1097+
10981098 # 5. "Truncate"
10991099 hiddens = hiddens.detach()
1100-
1100+
11011101 # 6. Remove the return of `hiddens`
11021102 # Returning loss in manual optimization is not needed
11031103 return None
You can’t perform that action at this time.
0 commit comments