Commit ba5aef4

Update test_pruning.py
1 parent f9674f5 commit ba5aef4

File tree

1 file changed: +77 −0


tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 77 additions & 0 deletions
@@ -338,3 +338,80 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
     assert not hasattr(model.layer.mlp_3, "weight_orig")
     model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
     assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+
+def test_sanitize_parameters_explicit_check():
+    """Test the sanitize_parameters_to_prune method with various attribute types."""
+
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.randn(5, 5))
+            self.bias = nn.Parameter(torch.randn(5))
+            self.some_bool = True
+            self.some_tensor = torch.randn(3, 3)  # Regular tensor, not parameter
+            self.some_string = "test"
+            self.some_none = None
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.test_module = TestModule()
+
+    model = TestModel()
+
+    parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
+        model,
+        parameters_to_prune=(),
+        parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"],
+    )
+
+    param_names_found = set()
+    for module, param_name in parameters_to_prune:
+        param = getattr(module, param_name)
+        assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}"
+        param_names_found.add(param_name)
+
+    assert "weight" in param_names_found
+    assert "bias" in param_names_found
+    assert "some_bool" not in param_names_found
+    assert "some_tensor" not in param_names_found
+    assert "some_string" not in param_names_found
+    assert "some_none" not in param_names_found
+
+
+def test_original_issue_reproduction():
+    """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""
+
+    class ProblematicModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = Sequential(
+                OrderedDict([
+                    ("mlp_1", nn.Linear(32, 32)),
+                    ("mlp_2", nn.Linear(32, 2)),
+                ])
+            )
+            # Add boolean attributes that would cause the original error
+            self.layer.mlp_1.training = True
+            self.layer.mlp_2.requires_grad = True
+
+    model = ProblematicModel()
+
+    try:
+        parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
+            model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
+        )
+
+        for module, param_name in parameters_to_prune:
+            param = getattr(module, param_name)
+            assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
+
+        success = True
+    except AttributeError as e:
+        if "'bool' object has no attribute 'is_cuda'" in str(e):
+            success = False  # Original bug still present
+        else:
+            raise  # Different error
+
+    assert success, "The fix for issue #10835 is not working correctly"
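
For context, a minimal sketch of the filtering behavior these tests exercise, assuming sanitize_parameters_to_prune keeps only (module, name) pairs whose attribute is a real nn.Parameter. This is an illustration, not Lightning's actual implementation, and sanitize_parameters_sketch is a hypothetical name:

import torch
from torch import nn


def sanitize_parameters_sketch(model: nn.Module, parameter_names: list) -> list:
    # Hypothetical sketch of the intent behind ModelPruning.sanitize_parameters_to_prune,
    # not the library's real code.
    parameters_to_prune = []
    for module in model.modules():
        for name in parameter_names:
            attr = getattr(module, name, None)
            # The isinstance check is the crux: booleans such as module.training,
            # plain tensors, strings, and None are skipped, so only true
            # nn.Parameter attributes are handed to torch.nn.utils.prune.
            if isinstance(attr, nn.Parameter):
                parameters_to_prune.append((module, name))
    return parameters_to_prune


if __name__ == "__main__":
    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(4, 4))
            self.flag = True  # non-parameter attribute, must be filtered out

    pairs = sanitize_parameters_sketch(Demo(), ["weight", "flag"])
    assert all(isinstance(getattr(m, n), nn.Parameter) for m, n in pairs)

With filtering along these lines, the regression test can request built-in attribute names such as training and requires_grad without tripping the "'bool' object has no attribute 'is_cuda'" AttributeError reported in issue #10835.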

0 commit comments
