Skip to content

Commit eab17c2

Browse files
committed
add testing
1 parent 2bba395 commit eab17c2

File tree

1 file changed

+94
-6
lines changed

1 file changed

+94
-6
lines changed

tests/tests_pytorch/callbacks/test_pruning.py

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
262262
actual = [m for m in actual if m.startswith("Applied")]
263263
percentage = r"\(\d+(?:\.\d+)?%\)"
264264
expected = [
265-
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
265+
rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
266266
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
267267
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
268-
rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
268+
rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
269269
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
270270
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
271-
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
271+
rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
272272
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
273273
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
274274
]
@@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
329329
actual = [m for m in actual if m.startswith("Applied")]
330330
percentage = r"\(\d+(?:\.\d+)?%\)"
331331
expected = [
332-
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
333-
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
334-
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
332+
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
333+
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
334+
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
335335
]
336336
expected = [re.compile(s) for s in expected]
337337
assert all(regex.match(s) for s, regex in zip(actual, expected))
@@ -463,3 +463,91 @@ def __init__(self):
463463
f"Actual weight_orig: {weight_orig}\n"
464464
f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
465465
)
466+
467+
468+
@pytest.mark.parametrize("pruning_amount", [0.1, 0.2, 0.3, 0.5])
469+
@pytest.mark.parametrize("model_type", ["simple", "complex"])
470+
def test_sparsity_calculation(tmp_path, caplog, pruning_amount: float, model_type: str):
471+
"""Test that the sparsity calculation fix correctly reports percentages."""
472+
473+
class SimpleModel(BoringModel):
474+
"""Simple model with 66 parameters (64 weight + 2 bias)."""
475+
476+
def __init__(self):
477+
super().__init__()
478+
self.layer = nn.Linear(32, 2) # 32*2 + 2 = 66 params
479+
480+
class ComplexModel(BoringModel):
481+
"""Complex model with multiple layers."""
482+
483+
def __init__(self):
484+
super().__init__()
485+
self.layer1 = nn.Linear(32, 64) # 32*64 + 64 = 2112 params
486+
self.layer2 = nn.Linear(64, 2) # 64*2 + 2 = 130 params
487+
# Total: 2112 + 130 = 2242 params (but only layer1 will be pruned)
488+
# layer1 params: 2112
489+
490+
def forward(self, x):
491+
x = torch.relu(self.layer1(x))
492+
return self.layer2(x)
493+
494+
if model_type == "simple":
495+
model = SimpleModel()
496+
expected_total_params = 66
497+
parameters_to_prune = None
498+
else:
499+
model = ComplexModel()
500+
expected_total_params = 2112
501+
parameters_to_prune = [(model.layer1, "weight"), (model.layer1, "bias")]
502+
503+
pruning = ModelPruning(
504+
pruning_fn="l1_unstructured",
505+
parameters_to_prune=parameters_to_prune,
506+
amount=pruning_amount,
507+
verbose=1,
508+
use_global_unstructured=True,
509+
)
510+
511+
trainer = Trainer(
512+
default_root_dir=tmp_path,
513+
enable_progress_bar=False,
514+
enable_model_summary=False,
515+
enable_checkpointing=False,
516+
logger=False,
517+
limit_train_batches=1,
518+
max_epochs=1,
519+
accelerator="cpu",
520+
callbacks=[pruning],
521+
)
522+
523+
with caplog.at_level(INFO):
524+
trainer.fit(model)
525+
526+
sparsity_logs = [msg for msg in caplog.messages if "Applied `L1Unstructured`. Pruned:" in msg]
527+
assert len(sparsity_logs) == 1, f"Expected 1 sparsity log, got {len(sparsity_logs)}"
528+
sparsity_log = sparsity_logs[0]
529+
pattern = r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)"
530+
match = re.search(pattern, sparsity_log)
531+
assert match, f"Could not parse sparsity log: {sparsity_log}"
532+
533+
total_params_before = int(match.group(1))
534+
pruned_count = int(match.group(2))
535+
total_params_after = int(match.group(3))
536+
sparsity_percentage = float(match.group(4))
537+
assert total_params_before == expected_total_params, (
538+
f"Total parameter count mismatch for {model_type} model. "
539+
f"Expected {expected_total_params}, got {total_params_before}"
540+
)
541+
assert total_params_after == expected_total_params, (
542+
f"Total parameter count should be consistent. Before: {total_params_before}, After: {total_params_after}"
543+
)
544+
545+
# Verify sparsity percentage is approximately correct
546+
expected_sparsity = pruning_amount * 100
547+
tolerance = 5.0
548+
assert abs(sparsity_percentage - expected_sparsity) <= tolerance
549+
550+
# Verify the number of pruned parameters is reasonable
551+
expected_pruned_count = int(expected_total_params * pruning_amount)
552+
pruned_tolerance = max(1, int(expected_total_params * 0.05))
553+
assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance

0 commit comments

Comments
 (0)