From 65808547a076fcfe0df0d42ae29cd6eb0099f953 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 24 Mar 2025 15:11:32 +0100
Subject: [PATCH] FIX Faulty test that results in nan weights

This specific test used a learning rate that was too high, resulting in
nan weights. Then, when the weights were compared to assert that they
differ, the test passed trivially because nan != nan. The lr is now
reduced and there is a sanity check that none of the weights contain
non-finite values.

See the discussion in
https://github.com/huggingface/peft/pull/2433#issuecomment-2747800312 ff.
---
 tests/test_custom_models.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index be3c8d4f5b..69358d018d 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -4134,9 +4134,9 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
         config._register_custom_module({custom_module_cls: custom_lora_cls})
         peft_model = get_peft_model(model, config)
 
-        sd_before = peft_model.state_dict()
+        sd_before = copy.deepcopy(peft_model.state_dict())
         inputs = torch.randn(16, 10)
-        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-1)
+        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-4)
 
         for _ in range(5):
             optimizer.zero_grad()
@@ -4146,6 +4146,13 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
             optimizer.step()
 
         sd_after = peft_model.state_dict()
+
+        # sanity check for finite results, since nan != nan would make the test pass trivially
+        for val in sd_before.values():
+            assert torch.isfinite(val).all()
+        for val in sd_after.values():
+            assert torch.isfinite(val).all()
+
         assert not torch.allclose(
             sd_before["base_model.model.my_module.lora_A.default.weight"],
             sd_after["base_model.model.my_module.lora_A.default.weight"],
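
For reference, a minimal standalone sketch (not part of the patch) of the
failure mode: torch.allclose returns False for nan-filled tensors by
default (equal_nan=False), so the "weights changed" assertion passed even
when training had diverged, whereas a torch.isfinite check catches it.

    import torch

    # stand-ins for the before/after state dict entries once training diverged
    a = torch.full((2,), float("nan"))
    b = torch.full((2,), float("nan"))

    # allclose(nan, nan) is False by default, so this passes even though
    # both tensors are all-nan, i.e. the comparison proved nothing
    assert not torch.allclose(a, b)

    # the sanity check added by the patch detects the divergence instead
    assert not torch.isfinite(a).all()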