diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py
index 4f3a20ce82de5..66ce47f0e7ad4 100644
--- a/tests/tests_pytorch/callbacks/test_lr_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -548,10 +548,10 @@
     def finetune_function(self, pl_module, epoch: int, optimizer):
         """Called when the epoch begins."""
         if epoch == 1 and isinstance(optimizer, torch.optim.SGD):
             self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
-        if epoch == 2 and isinstance(optimizer, torch.optim.Adam):
+        if epoch == 2 and type(optimizer) is torch.optim.Adam:
             self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)
-        if epoch == 3 and isinstance(optimizer, torch.optim.Adam):
+        if epoch == 3 and type(optimizer) is torch.optim.Adam:
             assert len(optimizer.param_groups) == 2
             self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
             assert len(optimizer.param_groups) == 3
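
For context on the change: `isinstance()` also matches subclasses, whereas `type(x) is Cls` requires the exact class. On recent PyTorch releases, `torch.optim.AdamW` is implemented as a subclass of `torch.optim.Adam`, so the two checks can diverge for AdamW instances. A minimal sketch of the difference (the AdamW-subclasses-Adam behavior is an assumption about the PyTorch version in use; on older releases both checks agree):

```python
# Minimal sketch: isinstance() vs. an exact type check on torch optimizers.
# Assumes a PyTorch release where torch.optim.AdamW subclasses torch.optim.Adam.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
adam = torch.optim.Adam(params)
adamw = torch.optim.AdamW(params)

# isinstance() accepts subclasses, so an AdamW instance can match Adam:
print(isinstance(adamw, torch.optim.Adam))  # True where AdamW subclasses Adam

# The exact type check tells the two optimizers apart:
print(type(adam) is torch.optim.Adam)       # True
print(type(adamw) is torch.optim.Adam)      # False
```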