@@ -338,3 +338,80 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
338338 assert not hasattr (model .layer .mlp_3 , "weight_orig" )
339339 model = TestModel .load_from_checkpoint (trainer .checkpoint_callback .last_model_path )
340340 assert not hasattr (model .layer .mlp_3 , "weight_orig" )
341+
342+
343+ def test_sanitize_parameters_explicit_check ():
344+ """Test the sanitize_parameters_to_prune method with various attribute types."""
345+
346+ class TestModule (nn .Module ):
347+ def __init__ (self ):
348+ super ().__init__ ()
349+ self .weight = nn .Parameter (torch .randn (5 , 5 ))
350+ self .bias = nn .Parameter (torch .randn (5 ))
351+ self .some_bool = True
352+ self .some_tensor = torch .randn (3 , 3 ) # Regular tensor, not parameter
353+ self .some_string = "test"
354+ self .some_none = None
355+
356+ class TestModel (BoringModel ):
357+ def __init__ (self ):
358+ super ().__init__ ()
359+ self .test_module = TestModule ()
360+
361+ model = TestModel ()
362+
363+ parameters_to_prune = ModelPruning .sanitize_parameters_to_prune (
364+ model ,
365+ parameters_to_prune = (),
366+ parameter_names = ["weight" , "bias" , "some_bool" , "some_tensor" , "some_string" , "some_none" ],
367+ )
368+
369+ param_names_found = set ()
370+ for module , param_name in parameters_to_prune :
371+ param = getattr (module , param_name )
372+ assert isinstance (param , nn .Parameter ), f"Expected Parameter, got { type (param )} "
373+ param_names_found .add (param_name )
374+
375+ assert "weight" in param_names_found
376+ assert "bias" in param_names_found
377+ assert "some_bool" not in param_names_found
378+ assert "some_tensor" not in param_names_found
379+ assert "some_string" not in param_names_found
380+ assert "some_none" not in param_names_found
381+
382+
383+ def test_original_issue_reproduction ():
384+ """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""
385+
386+ class ProblematicModel (BoringModel ):
387+ def __init__ (self ):
388+ super ().__init__ ()
389+ self .layer = Sequential (
390+ OrderedDict ([
391+ ("mlp_1" , nn .Linear (32 , 32 )),
392+ ("mlp_2" , nn .Linear (32 , 2 )),
393+ ])
394+ )
395+ # Add boolean attributes that would cause the original error
396+ self .layer .mlp_1 .training = True
397+ self .layer .mlp_2 .requires_grad = True
398+
399+ model = ProblematicModel ()
400+
401+ try :
402+ parameters_to_prune = ModelPruning .sanitize_parameters_to_prune (
403+ model , parameters_to_prune = (), parameter_names = ["weight" , "bias" , "training" , "requires_grad" ]
404+ )
405+
406+ for module , param_name in parameters_to_prune :
407+ param = getattr (module , param_name )
408+ assert isinstance (param , nn .Parameter ), f"Non-parameter found: { type (param )} "
409+
410+ success = True
411+ except AttributeError as e :
412+ if "'bool' object has no attribute 'is_cuda'" in str (e ):
413+ success = False # Original bug still present
414+ else :
415+ raise # Different error
416+
417+ assert success , "The fix for issue #10835 is not working correctly"
0 commit comments