Commit cb726a7
FIX Faulty test that results in nan weights (#2448)
This specific test used a learning rate that is too high, which drives the weights to nan. When the weights are then compared to assert that training changed them, the test passes trivially because nan != nan. The lr is now reduced, and a sanity check has been added that all weights are finite. See discussion in huggingface/peft#2433 (comment) ff.
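The failure mode described above follows from IEEE 754 semantics: NaN compares unequal to everything, including itself, so any "weights must differ after training" assertion is vacuously satisfied once training diverges. A minimal torch-free sketch (hypothetical values, standing in for the real tensors):

```python
import math

nan = float("nan")

# NaN is never equal to NaN, so a check of the form
# `assert before != after` passes even when both sides are NaN.
assert nan != nan
assert not (nan == nan)

# A finiteness sanity check catches this failure mode:
weights = [0.5, nan, 1.2]  # hypothetical weight values after divergence
assert not all(math.isfinite(w) for w in weights)  # NaN detected
```

This is why the commit adds `torch.isfinite(val).all()` assertions on both state dicts before comparing them.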
1 parent: 17ee725 · commit: cb726a7

File tree

1 file changed: +9 −2 lines

tests/test_custom_models.py

9 additions, 2 deletions

```diff
@@ -4134,9 +4134,9 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
         config._register_custom_module({custom_module_cls: custom_lora_cls})

         peft_model = get_peft_model(model, config)
-        sd_before = peft_model.state_dict()
+        sd_before = copy.deepcopy(peft_model.state_dict())
         inputs = torch.randn(16, 10)
-        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-1)
+        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-4)

         for _ in range(5):
             optimizer.zero_grad()
@@ -4146,6 +4146,13 @@ def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
             optimizer.step()

         sd_after = peft_model.state_dict()
+
+        # sanity check for finite results, since nan != nan, which would make the test pass trivially
+        for val in sd_before.values():
+            assert torch.isfinite(val).all()
+        for val in sd_after.values():
+            assert torch.isfinite(val).all()
+
         assert not torch.allclose(
             sd_before["base_model.model.my_module.lora_A.default.weight"],
             sd_after["base_model.model.my_module.lora_A.default.weight"],
```
