@@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
8484 # test train loader applies correct limits
8585 # ------------------------------------------------------
8686 trainer = Trainer (overfit_batches = 4 )
87+ trainer .data_connector .attach_dataloaders (model = model )
8788 trainer .reset_train_dataloader (model )
8889 assert trainer .num_training_batches == 4
8990
@@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
9394 assert torch .eq (ya , yb ).all ()
9495
9596 trainer = Trainer (overfit_batches = 0.11 )
97+ trainer .data_connector .attach_dataloaders (model = model )
9698 trainer .reset_train_dataloader (model )
9799 # The dataloader should have been overwritten with a Sequential sampler.
98100 assert trainer .train_dataloader is not train_loader
@@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
111113 # ------------------------------------------------------
112114 # test overfit_batches as percent
113115 # ------------------------------------------------------
114- loader_num_batches , dataloaders = Trainer (overfit_batches = 0.11 )._reset_eval_dataloader (split , model = model )
116+ trainer = Trainer (overfit_batches = 0.11 )
117+ trainer .data_connector .attach_dataloaders (model )
118+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
115119 assert loader_num_batches [0 ] == num_train_samples
116120
117121 # make sure we turned off shuffle for the user
@@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
125129 # ------------------------------------------------------
126130 # test overfit_batches as int
127131 # ------------------------------------------------------
128- loader_num_batches , dataloaders = Trainer (overfit_batches = 1 )._reset_eval_dataloader (split , model = model )
132+ trainer = Trainer (overfit_batches = 1 )
133+ trainer .data_connector .attach_dataloaders (model )
134+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
129135 assert loader_num_batches [0 ] == 1
130- loader_num_batches , dataloaders = Trainer (overfit_batches = 5 )._reset_eval_dataloader (split , model = model )
136+ trainer = Trainer (overfit_batches = 5 )
137+ trainer .data_connector .attach_dataloaders (model )
138+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
131139 assert loader_num_batches [0 ] == 5
132140
133141 # ------------------------------------------------------
134142 # test limit_xxx_batches as percent AND int
135143 # ------------------------------------------------------
136144 if split == RunningStage .VALIDATING :
137- loader_num_batches , dataloaders = Trainer (limit_val_batches = 0.1 )._reset_eval_dataloader (split , model = model )
145+ trainer = Trainer (limit_val_batches = 0.1 )
146+ trainer .data_connector .attach_dataloaders (model )
147+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
138148 assert loader_num_batches [0 ] == int (0.1 * len (val_loader ))
139149
140- loader_num_batches , dataloaders = Trainer (limit_val_batches = 10 )._reset_eval_dataloader (split , model = model )
150+ trainer = Trainer (limit_val_batches = 10 )
151+ trainer .data_connector .attach_dataloaders (model )
152+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
141153 assert loader_num_batches [0 ] == 10
142154 else :
143- loader_num_batches , dataloaders = Trainer (limit_test_batches = 0.1 )._reset_eval_dataloader (split , model = model )
155+ trainer = Trainer (limit_test_batches = 0.1 )
156+ trainer .data_connector .attach_dataloaders (model )
157+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
144158 assert loader_num_batches [0 ] == int (0.1 * len (test_loader ))
145159
146- loader_num_batches , dataloaders = Trainer (limit_test_batches = 10 )._reset_eval_dataloader (split , model = model )
160+ trainer = Trainer (limit_test_batches = 10 )
161+ trainer .data_connector .attach_dataloaders (model )
162+ loader_num_batches , dataloaders = trainer ._reset_eval_dataloader (split , model = model )
147163 assert loader_num_batches [0 ] == 10
0 commit comments