@@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
262262 actual = [m for m in actual if m .startswith ("Applied" )]
263263 percentage = r"\(\d+(?:\.\d+)?%\)"
264264 expected = [
265- rf"Applied `L1Unstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
265+ rf"Applied `L1Unstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
266266 rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ { percentage } " , # noqa: E501
267267 rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ { percentage } " , # noqa: E501
268- rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
268+ rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
269269 rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
270270 rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
271- rf"Applied `L1Unstructured`. Pruned: \d+\/1122 { percentage } -> \d+\/1122 { percentage } " ,
271+ rf"Applied `L1Unstructured`. Pruned: \d+\/1088 { percentage } -> \d+\/1088 { percentage } " ,
272272 rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
273273 rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ { percentage } -> \d+ { percentage } " , # noqa: E501
274274 ]
@@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
329329 actual = [m for m in actual if m .startswith ("Applied" )]
330330 percentage = r"\(\d+(?:\.\d+)?%\)"
331331 expected = [
332- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
333- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
334- rf"Applied `RandomUnstructured`. Pruned: \d+\/66 { percentage } -> \d+\/66 { percentage } " ,
332+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
333+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
334+ rf"Applied `RandomUnstructured`. Pruned: \d+\/64 { percentage } -> \d+\/64 { percentage } " ,
335335 ]
336336 expected = [re .compile (s ) for s in expected ]
337337 assert all (regex .match (s ) for s , regex in zip (actual , expected ))
@@ -463,3 +463,91 @@ def __init__(self):
463463 f"Actual weight_orig: { weight_orig } \n "
464464 f"Max difference: { torch .max (torch .abs (weight_orig - original_weights ))} "
465465 )
466+
467+
468+ @pytest .mark .parametrize ("pruning_amount" , [0.1 , 0.2 , 0.3 , 0.5 ])
469+ @pytest .mark .parametrize ("model_type" , ["simple" , "complex" ])
470+ def test_sparsity_calculation (tmp_path , caplog , pruning_amount : float , model_type : str ):
471+ """Test that the sparsity calculation fix correctly reports percentages."""
472+
473+ class SimpleModel (BoringModel ):
474+ """Simple model with 66 parameters (64 weight + 2 bias)."""
475+
476+ def __init__ (self ):
477+ super ().__init__ ()
478+ self .layer = nn .Linear (32 , 2 ) # 32*2 + 2 = 66 params
479+
480+ class ComplexModel (BoringModel ):
481+ """Complex model with multiple layers."""
482+
483+ def __init__ (self ):
484+ super ().__init__ ()
485+ self .layer1 = nn .Linear (32 , 64 ) # 32*64 + 64 = 2112 params
486+ self .layer2 = nn .Linear (64 , 2 ) # 64*2 + 2 = 130 params
487+ # Total: 2112 + 130 = 2242 params (but only layer1 will be pruned)
488+ # layer1 params: 2112
489+
490+ def forward (self , x ):
491+ x = torch .relu (self .layer1 (x ))
492+ return self .layer2 (x )
493+
494+ if model_type == "simple" :
495+ model = SimpleModel ()
496+ expected_total_params = 66
497+ parameters_to_prune = None
498+ else :
499+ model = ComplexModel ()
500+ expected_total_params = 2112
501+ parameters_to_prune = [(model .layer1 , "weight" ), (model .layer1 , "bias" )]
502+
503+ pruning = ModelPruning (
504+ pruning_fn = "l1_unstructured" ,
505+ parameters_to_prune = parameters_to_prune ,
506+ amount = pruning_amount ,
507+ verbose = 1 ,
508+ use_global_unstructured = True ,
509+ )
510+
511+ trainer = Trainer (
512+ default_root_dir = tmp_path ,
513+ enable_progress_bar = False ,
514+ enable_model_summary = False ,
515+ enable_checkpointing = False ,
516+ logger = False ,
517+ limit_train_batches = 1 ,
518+ max_epochs = 1 ,
519+ accelerator = "cpu" ,
520+ callbacks = [pruning ],
521+ )
522+
523+ with caplog .at_level (INFO ):
524+ trainer .fit (model )
525+
526+ sparsity_logs = [msg for msg in caplog .messages if "Applied `L1Unstructured`. Pruned:" in msg ]
527+ assert len (sparsity_logs ) == 1 , f"Expected 1 sparsity log, got { len (sparsity_logs )} "
528+ sparsity_log = sparsity_logs [0 ]
529+ pattern = r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)"
530+ match = re .search (pattern , sparsity_log )
531+ assert match , f"Could not parse sparsity log: { sparsity_log } "
532+
533+ total_params_before = int (match .group (1 ))
534+ pruned_count = int (match .group (2 ))
535+ total_params_after = int (match .group (3 ))
536+ sparsity_percentage = float (match .group (4 ))
537+ assert total_params_before == expected_total_params , (
538+ f"Total parameter count mismatch for { model_type } model. "
539+ f"Expected { expected_total_params } , got { total_params_before } "
540+ )
541+ assert total_params_after == expected_total_params , (
542+ f"Total parameter count should be consistent. Before: { total_params_before } , After: { total_params_after } "
543+ )
544+
545+ # Verify sparsity percentage is approximately correct
546+ expected_sparsity = pruning_amount * 100
547+ tolerance = 5.0
548+ assert abs (sparsity_percentage - expected_sparsity ) <= tolerance
549+
550+ # Verify the number of pruned parameters is reasonable
551+ expected_pruned_count = int (expected_total_params * pruning_amount )
552+ pruned_tolerance = max (1 , int (expected_total_params * 0.05 ))
553+ assert abs (pruned_count - expected_pruned_count ) <= pruned_tolerance
0 commit comments