diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index be3c8d4f5b..69358d018d 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -4134,9 +4134,9 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls): config._register_custom_module({custom_module_cls: custom_lora_cls}) peft_model = get_peft_model(model, config) - sd_before = peft_model.state_dict() + sd_before = copy.deepcopy(peft_model.state_dict()) inputs = torch.randn(16, 10) - optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-1) + optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-4) for _ in range(5): optimizer.zero_grad() @@ -4146,6 +4146,13 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls): optimizer.step() sd_after = peft_model.state_dict() + + # sanity check that for finite results, since nan != nan, which would make the test pass trivially + for val in sd_before.values(): + assert torch.isfinite(val).all() + for val in sd_after.values(): + assert torch.isfinite(val).all() + assert not torch.allclose( sd_before["base_model.model.my_module.lora_A.default.weight"], sd_after["base_model.model.my_module.lora_A.default.weight"],