@@ -22,23 +22,18 @@
 
 def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
     """
-    This tests ensures reduction works in un-balanced logging settings
+    This test ensures reduction works in unbalanced logging settings,
+    even when a Callback also logs.
     """
     class TestModel(BoringModel):
-
-        loss_1 = []
-        loss_2 = []
+        actual = {0: [], 1: []}
 
         def training_step(self, batch, batch_idx, optimizer_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            if optimizer_idx == 0 and self.trainer.global_step > 10:
-                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
-                self.loss_1.append(loss.detach().clone())
-            elif optimizer_idx == 1:
-                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
-                self.loss_2.append(loss.detach().clone())
-            return {"loss": loss}
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.log(f"loss_{optimizer_idx}", loss, on_epoch=True)
+            self.actual[optimizer_idx].append(loss)
+            return out
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
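
The first hunk ends just inside `configure_optimizers`, whose body the diff elides. For `optimizer_idx` to take the values 0 and 1 (the two keys of `actual`), it presumably returns two optimizers. A minimal sketch of such a setup inside `TestModel`, with a hypothetical second optimizer name (`optimizer_2`):

    def configure_optimizers(self):
        # two independent optimizers over the same parameters; Lightning then
        # calls training_step once per optimizer, passing optimizer_idx 0 and 1
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        return [optimizer, optimizer_2]
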
@@ -48,16 +43,28 @@ def configure_optimizers(self):
     model = TestModel()
     model.training_epoch_end = None
 
+    class TestCallback(pl.Callback):
+        def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx):
+            # when this is called, the EpochResultStore state has not been reset yet because we are still
+            # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the
+            # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index
+            # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459
+            pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx)
+
     # Initialize a trainer
     trainer = pl.Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
+        limit_train_batches=5,
+        limit_val_batches=5,
+        callbacks=[TestCallback()],
+        weights_summary=None,
     )
-
     trainer.fit(model)
 
-    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
-    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
-    # test loss are properly reduced
-    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
-    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
+    for k, v in model.actual.items():
+        assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1])
+        # test loss is properly reduced
+        torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())
+
+    assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1
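
The epoch-level assertion relies on Lightning's default reduction for values logged with `on_epoch=True`: the `*_step` metric carries the latest logged value, while the `*_epoch` metric is the mean over the steps. A minimal standalone sketch of that reduction in plain PyTorch, using hypothetical loss values in place of `model.actual[optimizer_idx]`:

    import torch

    # hypothetical per-step losses logged during one epoch
    losses = [torch.tensor(v) for v in (0.9, 0.7, 0.4, 0.3, 0.2)]

    step_metric = losses[-1]                   # what "loss_{k}_step" holds
    epoch_metric = torch.stack(losses).mean()  # what "loss_{k}_epoch" holds

    # mirrors the reduction check performed by the test above
    torch.testing.assert_allclose(epoch_metric, torch.stack(losses).mean())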