
Commit d5acc0c

ananthsub authored and lexierule committed
Fix truncated backprop through time when set on LightningModule and not Trainer (#8804)
* Fix truncated backprop through time set on LightningModule and not Trainer

(cherry picked from commit c4a1c8b)
1 parent 9e44c65 commit d5acc0c
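Background for the fix: TBPTT can be enabled either by passing `truncated_bptt_steps` to the `Trainer` or by setting it as a property on the `LightningModule`, but the batch-splitting loop previously consulted only the Trainer flag, so the module-level setting was silently ignored. A minimal sketch of the module-level usage this commit fixes (the model, layer sizes, and optimizer here are illustrative, not taken from the commit):

    import torch
    from pytorch_lightning import LightningModule

    class SketchRNN(LightningModule):  # hypothetical model, for illustration only
        def __init__(self):
            super().__init__()
            self.rnn = torch.nn.LSTM(input_size=1, hidden_size=8, batch_first=True)
            self.head = torch.nn.Linear(8, 1)
            # Set on the module instead of Trainer(truncated_bptt_steps=2);
            # before this fix, this path never triggered batch splitting.
            self.truncated_bptt_steps = 2

        def training_step(self, batch, batch_idx, hiddens):
            # With TBPTT active, `batch` is one time-slice of the full sequence
            # and `hiddens` carries the detached state from the previous slice.
            x, y = batch
            out, hiddens = self.rnn(x, hiddens)
            loss = torch.nn.functional.mse_loss(self.head(out), y)
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)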

File tree

5 files changed: +177 −149 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



+- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))
+
+
 ## [1.4.0] - 2021-07-27

 ### Added

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 7 additions & 5 deletions
@@ -473,11 +473,13 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         Args:
             batch: the current batch to split
         """
-        splits = [batch]
-        if self.trainer.truncated_bptt_steps is not None:
-            model_ref = self.trainer.lightning_module
-            with self.trainer.profiler.profile("tbptt_split_batch"):
-                splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
+        tbptt_steps = self._truncated_bptt_steps()
+        if tbptt_steps == 0:
+            return [batch]
+
+        model_ref = self.trainer.lightning_module
+        with self.trainer.profiler.profile("tbptt_split_batch"):
+            splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
         return splits

     def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
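The `_truncated_bptt_steps()` helper called above is not part of this hunk. A sketch of the resolution order it implies, assuming the `LightningModule` property takes precedence over the `Trainer` argument (which is what the new tests below exercise):

    def _truncated_bptt_steps(self) -> int:
        # Sketch only; the actual helper is defined elsewhere in this PR.
        # Prefer the property on the LightningModule (it defaults to 0 there),
        # falling back to the Trainer argument (which may be None).
        lightning_module = self.trainer.lightning_module
        if lightning_module.truncated_bptt_steps > 0:
            return lightning_module.truncated_bptt_steps
        return self.trainer.truncated_bptt_steps or 0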

tests/models/test_cpu.py

Lines changed: 0 additions & 74 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os

-import pytest
 import torch

 import tests.helpers.pipelines as tpipes
@@ -311,76 +310,3 @@ def test_all_features_cpu_model(tmpdir):
     model = BoringModel()

     tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01)
-
-
-@pytest.mark.parametrize("n_hidden_states", [1, 2])
-def test_tbptt_cpu_model(tmpdir, n_hidden_states):
-    """Test truncated back propagation through time works."""
-    truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
-
-        def __len__(self):
-            return 1
-
-    class BpttTestModel(BoringModel):
-        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.test_hidden = None
-            self.batch_size = batch_size
-            self.layer = torch.nn.Linear(in_features, out_features)
-            self.n_hidden_states = n_hidden_states
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            if self.n_hidden_states == 1:
-                self.test_hidden = torch.rand(1)
-            else:
-                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
-
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
-
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
-            return {"loss": loss_val, "hiddens": self.test_hidden}
-
-        def training_epoch_end(self, training_step_outputs):
-            training_step_outputs = training_step_outputs[0]
-            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
-            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
-            self.log("train_loss", loss)
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
-            )
-
-    model = BpttTestModel(
-        batch_size=batch_size,
-        in_features=truncated_bptt_steps,
-        out_features=truncated_bptt_steps,
-        n_hidden_states=n_hidden_states,
-    )
-    model.example_input_array = torch.randn(5, truncated_bptt_steps)
-
-    # fit model
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        truncated_bptt_steps=truncated_bptt_steps,
-        limit_val_batches=0,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
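Both hunks above remove `test_tbptt_cpu_model` (and its now-unused `pytest` import) from the CPU model tests; the test reappears below, extended with a `property_on_module` parameter, in the new TBPTT test file this commit adds.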
New file

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from tests.helpers import BoringModel
+
+
+@pytest.mark.parametrize("n_hidden_states", (1, 2))
+@pytest.mark.parametrize("property_on_module", (False, True))
+def test_tbptt_cpu_model(tmpdir, n_hidden_states, property_on_module):
+    """Test truncated back propagation through time works."""
+    truncated_bptt_steps = 2
+    sequence_size = 30
+    batch_size = 30
+
+    x_seq = torch.rand(batch_size, sequence_size, 1)
+    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+        def __getitem__(self, i):
+            return x_seq, y_seq_list
+
+        def __len__(self):
+            return 1
+
+    class BpttTestModel(BoringModel):
+        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.test_hidden = None
+            self.batch_size = batch_size
+            self.layer = torch.nn.Linear(in_features, out_features)
+            self.n_hidden_states = n_hidden_states
+            if property_on_module:
+                self.truncated_bptt_steps = truncated_bptt_steps
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            if self.n_hidden_states == 1:
+                self.test_hidden = torch.rand(1)
+            else:
+                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
+
+            x_tensor, y_list = batch
+            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+
+            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
+            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
+
+            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
+            loss_val = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            return {"loss": loss_val, "hiddens": self.test_hidden}
+
+        def training_epoch_end(self, training_step_outputs):
+            training_step_outputs = training_step_outputs[0]
+            assert len(training_step_outputs) == (sequence_size / truncated_bptt_steps)
+            loss = torch.stack([x["loss"] for x in training_step_outputs]).mean()
+            self.log("train_loss", loss)
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(
+                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
+            )
+
+    model = BpttTestModel(
+        batch_size=batch_size,
+        in_features=truncated_bptt_steps,
+        out_features=truncated_bptt_steps,
+        n_hidden_states=n_hidden_states,
+    )
+    model.example_input_array = torch.randn(5, truncated_bptt_steps)
+
+    trainer_tbptt_steps = None if property_on_module else truncated_bptt_steps
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        truncated_bptt_steps=trainer_tbptt_steps,
+        limit_val_batches=0,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
+
+
+def test_tbptt_log(tmpdir):
+    truncated_bptt_steps = 2
+    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
+    batch_size = 10
+    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
+
+    class MockSeq2SeqDataset(torch.utils.data.Dataset):
+        def __init__(self):
+            self.x_seq = torch.randn(N, T, F)
+            self.y_seq = torch.randn(N, T, F)
+
+        def __getitem__(self, index):
+            return self.x_seq[index], self.y_seq[index]
+
+        def __len__(self):
+            return N
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.test_hidden = None
+            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
+            self.truncated_bptt_steps = truncated_bptt_steps
+
+        def training_step(self, batch, batch_idx, hiddens):
+            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
+            if hiddens is not None:
+                assert hiddens.grad_fn is None
+            split_idx = self.trainer.fit_loop.split_idx
+            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)
+
+            x, y = batch
+            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
+                # last split idx, not aligned
+                assert x.shape[1] == T % truncated_bptt_steps
+                assert y.shape[1] == T % truncated_bptt_steps
+            else:
+                assert x.shape[1] == truncated_bptt_steps
+                assert y.shape[1] == truncated_bptt_steps
+
+            pred, _ = self(x)
+            loss = torch.nn.functional.mse_loss(pred, y)
+
+            self.log("a", loss, on_epoch=True)
+
+            return {"loss": loss, "hiddens": self.test_hidden}
+
+        def on_train_batch_start(self, *args, **kwargs) -> None:
+            self.test_hidden = None
+
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=2,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    assert trainer.fit_loop.batch_idx == N // batch_size
+    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
+    assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
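Note how the new parametrization covers both enablement paths: when `property_on_module` is true, `truncated_bptt_steps` is set on `BpttTestModel` and `trainer_tbptt_steps` resolves to `None`, so only the module property is active; otherwise the value is passed to the `Trainer` as before. `test_tbptt_log` likewise now relies solely on the module property (`self.truncated_bptt_steps` in `TestModel.__init__`) rather than a `Trainer` argument.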

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 0 additions & 70 deletions
@@ -213,76 +213,6 @@ def validation_step(self, batch, batch_idx):
     assert trainer.logged_metrics["bar"] == result


-def test_tbptt_log(tmpdir):
-    truncated_bptt_steps = 2
-    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
-    batch_size = 10
-    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
-
-    class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __init__(self):
-            self.x_seq = torch.randn(N, T, F)
-            self.y_seq = torch.randn(N, T, F)
-
-        def __getitem__(self, index):
-            return self.x_seq[index], self.y_seq[index]
-
-        def __len__(self):
-            return N
-
-    class TestModel(BoringModel):
-        def __init__(self):
-            super().__init__()
-            self.test_hidden = None
-            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
-
-        def training_step(self, batch, batch_idx, hiddens):
-            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            if hiddens is not None:
-                assert hiddens.grad_fn is None
-            split_idx = self.trainer.fit_loop.split_idx
-            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)
-
-            x, y = batch
-            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
-                # last split idx, not aligned
-                assert x.shape[1] == T % truncated_bptt_steps
-                assert y.shape[1] == T % truncated_bptt_steps
-            else:
-                assert x.shape[1] == truncated_bptt_steps
-                assert y.shape[1] == truncated_bptt_steps
-
-            pred, _ = self(x)
-            loss = torch.nn.functional.mse_loss(pred, y)
-
-            self.log("a", loss, on_epoch=True)
-
-            return {"loss": loss, "hiddens": self.test_hidden}
-
-        def on_train_batch_start(self, *args, **kwargs) -> None:
-            self.test_hidden = None
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
-
-    model = TestModel()
-    model.training_epoch_end = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_val_batches=0,
-        truncated_bptt_steps=truncated_bptt_steps,
-        max_epochs=2,
-        log_every_n_steps=2,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    assert trainer.fit_loop.batch_idx == N // batch_size
-    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
-    assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
-
-
 def test_different_batch_types_for_sizing(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
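This removed `test_tbptt_log` is the same test added in the new file above; it was relocated rather than deleted, with `truncated_bptt_steps` moved from the `Trainer` constructor onto the model.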
