Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/pytorch/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,10 @@ def sanitize_parameters_to_prune(

if not parameters_to_prune:
parameters_to_prune = [
(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
(m, p)
for p in parameters
for m in current_modules
if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter)
]
elif (
isinstance(parameters_to_prune, (list, tuple))
Expand Down
77 changes: 77 additions & 0 deletions tests/tests_pytorch/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,80 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
assert not hasattr(model.layer.mlp_3, "weight_orig")
model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")


def test_sanitize_parameters_explicit_check():
"""Test the sanitize_parameters_to_prune method with various attribute types."""

class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.randn(5, 5))
self.bias = nn.Parameter(torch.randn(5))
self.some_bool = True
self.some_tensor = torch.randn(3, 3) # Regular tensor, not parameter
self.some_string = "test"
self.some_none = None

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.test_module = TestModule()

model = TestModel()

parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
model,
parameters_to_prune=(),
parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"],
)

param_names_found = set()
for module, param_name in parameters_to_prune:
param = getattr(module, param_name)
assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}"
param_names_found.add(param_name)

assert "weight" in param_names_found
assert "bias" in param_names_found
assert "some_bool" not in param_names_found
assert "some_tensor" not in param_names_found
assert "some_string" not in param_names_found
assert "some_none" not in param_names_found


def test_original_issue_reproduction():
"""Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""

class ProblematicModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = Sequential(
OrderedDict([
("mlp_1", nn.Linear(32, 32)),
("mlp_2", nn.Linear(32, 2)),
])
)
# Add boolean attributes that would cause the original error
self.layer.mlp_1.training = True
self.layer.mlp_2.requires_grad = True

model = ProblematicModel()

try:
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
)

for module, param_name in parameters_to_prune:
param = getattr(module, param_name)
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"

success = True
except AttributeError as e:
if "'bool' object has no attribute 'is_cuda'" in str(e):
success = False # Original bug still present
else:
raise # Different error

assert success, "The fix for issue #10835 is not working correctly"
Loading