@@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
262
262
actual = [m for m in actual if m .startswith ("Applied" )]
263
263
percentage = r"\(\d+(?:\.\d+)?%\)"
264
264
expected = [
265
- rf"Applied `L1Unstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
265
+ rf"Applied `L1Unstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
266
266
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ { percentage } " , # noqa: E501
267
267
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ { percentage } " , # noqa: E501
268
- rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
268
+ rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
269
269
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
270
270
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
271
- rf"Applied `L1Unstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
271
+ rf"Applied `L1Unstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
272
272
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
273
273
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
274
274
]
@@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
329
329
actual = [m for m in actual if m .startswith ("Applied" )]
330
330
percentage = r"\(\d+(?:\.\d+)?%\)"
331
331
expected = [
332
- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
333
- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
334
- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
332
+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
333
+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
334
+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
335
335
]
336
336
expected = [re .compile (s ) for s in expected ]
337
337
assert all (regex .match (s ) for s , regex in zip (actual , expected ))
@@ -463,3 +463,91 @@ def __init__(self):
463
463
f"Actual weight_orig: { weight_orig } \n "
464
464
f"Max difference: { torch .max (torch .abs (weight_orig - original_weights ))} "
465
465
)
466
+
467
+
468
@pytest.mark.parametrize("pruning_amount", [0.1, 0.2, 0.3, 0.5])
@pytest.mark.parametrize("model_type", ["simple", "complex"])
def test_sparsity_calculation(tmp_path, caplog, pruning_amount: float, model_type: str):
    """Test that the sparsity calculation correctly reports counts and percentages.

    The ``Pruned: x/N`` log line must use as denominator ``N`` only the
    parameters actually registered for pruning — the default registration
    covers ``weight`` tensors only (bias excluded), while an explicit
    ``parameters_to_prune`` list covers exactly the listed tensors.
    """

    class SimpleModel(BoringModel):
        """Single ``Linear(32, 2)`` layer: 64 weight + 2 bias parameters."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)  # weight: 32*2 = 64, bias: 2

    class ComplexModel(BoringModel):
        """Two layers; only ``layer1`` is registered for pruning below."""

        def __init__(self):
            super().__init__()
            self.layer1 = nn.Linear(32, 64)  # weight: 32*64 = 2048, bias: 64 -> 2112 prunable
            self.layer2 = nn.Linear(64, 2)  # never registered for pruning

        def forward(self, x):
            x = torch.relu(self.layer1(x))
            return self.layer2(x)

    if model_type == "simple":
        model = SimpleModel()
        # ``parameters_to_prune=None`` falls back to the default registration,
        # which prunes only ``weight`` tensors, so the denominator is
        # 32 * 2 = 64 — NOT 66 (the bias is excluded).
        expected_total_params = 64
        parameters_to_prune = None
    else:
        model = ComplexModel()
        # weight (2048) + bias (64) of layer1 are registered explicitly.
        expected_total_params = 2112
        parameters_to_prune = [(model.layer1, "weight"), (model.layer1, "bias")]

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",
        parameters_to_prune=parameters_to_prune,
        amount=pruning_amount,
        verbose=1,  # verbose=1 emits the global "Applied ..." summary line parsed below
        use_global_unstructured=True,
    )

    trainer = Trainer(
        default_root_dir=tmp_path,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=1,
        max_epochs=1,  # a single pruning application -> exactly one summary log line
        accelerator="cpu",
        callbacks=[pruning],
    )

    with caplog.at_level(INFO):
        trainer.fit(model)

    # Exactly one global summary line is expected for one epoch of pruning.
    sparsity_logs = [msg for msg in caplog.messages if "Applied `L1Unstructured`. Pruned:" in msg]
    assert len(sparsity_logs) == 1, f"Expected 1 sparsity log, got {len(sparsity_logs)}"
    sparsity_log = sparsity_logs[0]
    # Groups: (1) denominator before, (2) pruned count after, (3) denominator after, (4) sparsity %.
    pattern = r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)"
    match = re.search(pattern, sparsity_log)
    assert match, f"Could not parse sparsity log: {sparsity_log}"

    total_params_before = int(match.group(1))
    pruned_count = int(match.group(2))
    total_params_after = int(match.group(3))
    sparsity_percentage = float(match.group(4))
    assert total_params_before == expected_total_params, (
        f"Total parameter count mismatch for {model_type} model. "
        f"Expected {expected_total_params}, got {total_params_before}"
    )
    assert total_params_after == expected_total_params, (
        f"Total parameter count should be consistent. Before: {total_params_before}, After: {total_params_after}"
    )

    # Verify sparsity percentage is approximately correct (L1 global pruning
    # hits the requested amount up to rounding; allow a 5-point tolerance).
    expected_sparsity = pruning_amount * 100
    tolerance = 5.0
    assert abs(sparsity_percentage - expected_sparsity) <= tolerance

    # Verify the number of pruned parameters is reasonable (within 5% of total).
    expected_pruned_count = int(expected_total_params * pruning_amount)
    pruned_tolerance = max(1, int(expected_total_params * 0.05))
    assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance
0 commit comments