diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index a55c21a5c6ed1..93e828db0d24f 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -458,7 +458,10 @@ def sanitize_parameters_to_prune( if not parameters_to_prune: parameters_to_prune = [ - (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None + (m, p) + for p in parameters + for m in current_modules + if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter) ] elif ( isinstance(parameters_to_prune, (list, tuple)) diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index d70ab68b78b32..6efe9b9992d00 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -338,3 +338,70 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): assert not hasattr(model.layer.mlp_3, "weight_orig") model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path) assert not hasattr(model.layer.mlp_3, "weight_orig") + + +def test_sanitize_parameters_explicit_check(): + """Test the sanitize_parameters_to_prune method with various attribute types.""" + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(5, 5)) + self.bias = nn.Parameter(torch.randn(5)) + self.some_bool = True + self.some_tensor = torch.randn(3, 3) # Regular tensor, not parameter + self.some_string = "test" + self.some_none = None + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.test_module = TestModule() + + model = TestModel() + + parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( + model, + parameters_to_prune=(), + parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"], + ) + + param_names_found = set() + for module, param_name in parameters_to_prune: + param = getattr(module, param_name) + assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}" + param_names_found.add(param_name) + + assert "weight" in param_names_found + assert "bias" in param_names_found + assert "some_bool" not in param_names_found + assert "some_tensor" not in param_names_found + assert "some_string" not in param_names_found + assert "some_none" not in param_names_found + + +def test_original_issue_reproduction(): + """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835.""" + + class ProblematicModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = Sequential( + OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 2)), + ]) + ) + # Add boolean attributes that would cause the original error + self.layer.mlp_1.training = True + self.layer.mlp_2.requires_grad = True + + model = ProblematicModel() + + parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( + model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"] + ) + + for module, param_name in parameters_to_prune: + param = getattr(module, param_name) + assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"