Skip to content

Commit db0b6fa

Browse files
authored
Merge pull request #2 from rittik9/patch-3
Update `test_pruning.py`
2 parents 0c783f3 + 2ddda83 commit db0b6fa

File tree

1 file changed

+6
-16
lines changed

1 file changed

+6
-16
lines changed

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -398,20 +398,10 @@ def __init__(self):
398398

399399
model = ProblematicModel()
400400

401-
try:
402-
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
403-
model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
404-
)
405-
406-
for module, param_name in parameters_to_prune:
407-
param = getattr(module, param_name)
408-
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
409-
410-
success = True
411-
except AttributeError as e:
412-
if "'bool' object has no attribute 'is_cuda'" in str(e):
413-
success = False # Original bug still present
414-
else:
415-
raise # Different error
401+
parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
402+
model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
403+
)
416404

417-
assert success, "The fix for issue #10835 is not working correctly"
405+
for module, param_name in parameters_to_prune:
406+
param = getattr(module, param_name)
407+
assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"

0 commit comments

Comments
 (0)