1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Optional , Union
15+ from typing import Callable , Optional , Union
1616
1717import pytorch_lightning as pl
1818from pytorch_lightning .trainer .supporters import prefetch_iterator
@@ -117,19 +117,23 @@ def attach_dataloaders(
117117 # functions to overwrite with these implementations
118118 if train_dataloaders is not None :
119119 self .trainer .train_dataloader = None
120- model .train_dataloader = _PatchDataLoader (train_dataloaders )
120+ train_dataloader = _PatchDataLoader (train_dataloaders , "train" )
121+ train_dataloader .patch (model )
121122
122123 if val_dataloaders is not None :
123124 self .trainer .val_dataloaders = None
124- model .val_dataloader = _PatchDataLoader (val_dataloaders )
125+ val_dataloader = _PatchDataLoader (val_dataloaders , "val" )
126+ val_dataloader .patch (model )
125127
126128 if test_dataloaders is not None :
127129 self .trainer .test_dataloaders = None
128- model .test_dataloader = _PatchDataLoader (test_dataloaders )
130+ test_dataloader = _PatchDataLoader (test_dataloaders , "test" )
131+ test_dataloader .patch (model )
129132
130133 if predict_dataloaders is not None :
131134 self .trainer .predict_dataloaders = None
132- model .predict_dataloader = _PatchDataLoader (predict_dataloaders )
135+ predict_dataloader = _PatchDataLoader (predict_dataloaders , "predict" )
136+ predict_dataloader .patch (model )
133137
134138 def attach_datamodule (
135139 self , model : "pl.LightningModule" , datamodule : Optional ["pl.LightningDataModule" ] = None
@@ -157,6 +161,13 @@ def attach_datamodule(
157161 if hasattr (datamodule , "data_pipeline" ):
158162 model .data_pipeline = datamodule .data_pipeline
159163
164+ @staticmethod
165+ def detach_data (model : "pl.LightningModule" ) -> None :
166+ for stage in ("train" , "val" , "test" , "predict" ):
167+ loader = getattr (model , f"{ stage } _dataloader" , None )
168+ if isinstance (loader , _PatchDataLoader ):
169+ loader .unpatch (model )
170+
160171
161172class _PatchDataLoader :
162173 r"""
@@ -167,13 +178,23 @@ class _PatchDataLoader:
167178 dataloader: Dataloader object to return when called.
168179 """
169180
170- def __init__ (self , dataloader : Union [TRAIN_DATALOADERS , EVAL_DATALOADERS ]) -> None :
181+ def __init__ (self , dataloader : Union [TRAIN_DATALOADERS , EVAL_DATALOADERS ], stage : str ) -> None :
171182 self .dataloader = dataloader
172183
173184 # cannot pickle __code__ so cannot verify if PatchDataloader
174185 # exists which shows dataloader methods have been overwritten.
175186 # so, we hack it by using the string representation
176187 self .patch_loader_code = str (self .__call__ .__code__ )
188+ self .old_loader : Optional [Callable ] = None
189+ self .stage = stage
177190
178191 def __call__ (self ) -> Union [TRAIN_DATALOADERS , EVAL_DATALOADERS ]:
179192 return self .dataloader
193+
194+ def patch (self , model : "pl.LightningModule" ) -> None :
195+ self ._old_loader = getattr (model , self .stage + "_dataloader" )
196+ setattr (model , self .stage + "_dataloader" , self )
197+
198+ def unpatch (self , model : "pl.LightningModule" ) -> None :
199+ setattr (model , self .stage + "_dataloader" , self ._old_loader )
200+ self ._old_loader = None
0 commit comments