Commit 403aa4d

justusschock authored and lexierule committed
[Bugfix] Detach Loaders after running entrypoint (#8885)
detach loaders after run
1 parent 7239860 commit 403aa4d

File tree

4 files changed, +91 −6 lines changed

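In this release, loaders passed directly to an entrypoint (e.g. `trainer.fit(model, train_dataloaders=loader)`) are attached by overwriting the model's `train_dataloader`/`val_dataloader`/`test_dataloader`/`predict_dataloader` methods with callable `_PatchDataLoader` wrappers. Previously the wrappers were left in place after the run, so `model.train_dataloader()` kept returning the loader passed to the entrypoint rather than the method defined on the model. This commit makes each wrapper remember the method it replaces (`patch`) and restores it (`unpatch`) from the trainer's teardown hook.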

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.4.1] - 2021-08-03
 
+- Restore original loaders if replaced by entrypoint ([#8885](https://github.com/PyTorchLightning/pytorch-lightning/pull/8885))
+
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
 - Fixed references for `ResultCollection.extra` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))
 - Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
```

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 27 additions & 6 deletions

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.supporters import prefetch_iterator
@@ -117,19 +117,23 @@ def attach_dataloaders(
         # functions to overwrite with these implementations
         if train_dataloaders is not None:
             self.trainer.train_dataloader = None
-            model.train_dataloader = _PatchDataLoader(train_dataloaders)
+            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
+            train_dataloader.patch(model)
 
         if val_dataloaders is not None:
             self.trainer.val_dataloaders = None
-            model.val_dataloader = _PatchDataLoader(val_dataloaders)
+            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
+            val_dataloader.patch(model)
 
         if test_dataloaders is not None:
             self.trainer.test_dataloaders = None
-            model.test_dataloader = _PatchDataLoader(test_dataloaders)
+            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
+            test_dataloader.patch(model)
 
         if predict_dataloaders is not None:
             self.trainer.predict_dataloaders = None
-            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
+            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
+            predict_dataloader.patch(model)
 
     def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -157,6 +161,13 @@ def attach_datamodule(
         if hasattr(datamodule, "data_pipeline"):
             model.data_pipeline = datamodule.data_pipeline
 
+    @staticmethod
+    def detach_data(model: "pl.LightningModule") -> None:
+        for stage in ("train", "val", "test", "predict"):
+            loader = getattr(model, f"{stage}_dataloader", None)
+            if isinstance(loader, _PatchDataLoader):
+                loader.unpatch(model)
+
 
 class _PatchDataLoader:
     r"""
@@ -167,13 +178,23 @@ class _PatchDataLoader:
        dataloader: Dataloader object to return when called.
    """
 
-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
         self.dataloader = dataloader
 
         # cannot pickle __code__ so cannot verify if PatchDataloader
         # exists which shows dataloader methods have been overwritten.
         # so, we hack it by using the string representation
         self.patch_loader_code = str(self.__call__.__code__)
+        self.old_loader: Optional[Callable] = None
+        self.stage = stage
 
     def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         return self.dataloader
+
+    def patch(self, model: "pl.LightningModule") -> None:
+        self._old_loader = getattr(model, self.stage + "_dataloader")
+        setattr(model, self.stage + "_dataloader", self)
+
+    def unpatch(self, model: "pl.LightningModule") -> None:
+        setattr(model, self.stage + "_dataloader", self._old_loader)
+        self._old_loader = None
```
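The `patch`/`unpatch` pair is a plain save-and-restore on instance attributes: `patch` stashes the model's original bound `*_dataloader` method and shadows it with the callable `_PatchDataLoader` instance (an instance attribute takes precedence over a method defined on the class), and `unpatch` writes the stashed method back. A minimal, self-contained sketch of the same idea — `Model` and the list-valued "loaders" here are hypothetical stand-ins, not Lightning classes:

```python
from typing import Any, Callable, List, Optional


class PatchedLoader:
    """Callable stand-in for a model's ``*_dataloader`` method (sketch of ``_PatchDataLoader``)."""

    def __init__(self, dataloader: Any, stage: str) -> None:
        self.dataloader = dataloader
        self.stage = stage
        self._old_loader: Optional[Callable] = None

    def __call__(self) -> Any:
        # Calling the patched attribute returns the loaders passed to the entrypoint.
        return self.dataloader

    def patch(self, model: Any) -> None:
        # getattr returns the bound method; stash it so it can be restored later.
        self._old_loader = getattr(model, self.stage + "_dataloader")
        # The instance attribute shadows the method defined on the class.
        setattr(model, self.stage + "_dataloader", self)

    def unpatch(self, model: Any) -> None:
        # Put the original bound method back and drop the stashed reference.
        setattr(model, self.stage + "_dataloader", self._old_loader)
        self._old_loader = None


class Model:  # hypothetical stand-in for a LightningModule
    def train_dataloader(self) -> List[str]:
        return ["loader defined on the model"]


model = Model()
patched = PatchedLoader(["loader passed to the entrypoint"], "train")

patched.patch(model)
assert model.train_dataloader() == ["loader passed to the entrypoint"]  # during the run

patched.unpatch(model)
assert model.train_dataloader() == ["loader defined on the model"]  # restored afterwards
```

The trainer-side `detach_data` then simply scans the four stage attributes and calls `unpatch` on any that are `_PatchDataLoader` instances, which makes the restore a no-op for stages that were never patched.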

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -1191,6 +1191,9 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None:
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
         self.profiler.teardown(stage=fn)
+
+        self.data_connector.detach_data(self.lightning_module)
+
         self.teardown(stage=fn)
         model.teardown(stage=fn)
```
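Wiring the restore into `_call_teardown_hook` places it after the datamodule and profiler teardowns but before the trainer's and the model's own `teardown` hooks, so user teardown code already sees the original loader methods. And since every entrypoint (`fit`, `validate`, `test`, `predict`) finishes through this hook, the single call covers all four code paths.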

tests/trainer/test_data_loading.py

Lines changed: 59 additions & 0 deletions

```diff
@@ -254,3 +254,62 @@ class CustomSampler(Sampler):
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
     with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
         trainer.auto_add_sampler(dataloader, shuffle=True)
+
+
+def test_loader_detaching():
+    """Checks that the loaders have been reset after the entrypoint"""
+
+    class LoaderTestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert len(model.train_dataloader()) == 10
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            assert len(model.val_dataloader()) == 10
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, batch, batch_idx):
+            assert len(model.test_dataloader()) == 10
+            return super().test_step(batch, batch_idx)
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            assert len(model.predict_dataloader()) == 10
+            return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+    loader = DataLoader(RandomDataset(32, 10), batch_size=1)
+
+    model = LoaderTestModel()
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, loader, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.validate(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.predict(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
+
+    trainer.test(model, loader)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert len(model.predict_dataloader()) == 64
+    assert len(model.test_dataloader()) == 64
```
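The two magic lengths separate the loader sources: the `DataLoader` built in the test wraps a 10-sample `RandomDataset` with `batch_size=1`, so a length of 10 inside the `*_step` hooks proves the patched methods are active during the run, while `BoringModel`'s own dataloader methods wrap a 64-sample `RandomDataset` (as the pre-run assertions show), so a length of 64 proves the original methods are back. Each block of four assertions after an entrypoint therefore checks that all four stages were unpatched, not only the ones that entrypoint used.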
