Skip to content

Commit 2682e09

Browse files
committed
Fix selective layer pruning tests and calibration validation
1 parent 55b5c02 commit 2682e09

File tree

5 files changed

+6
-5
lines changed

5 files changed

+6
-5
lines changed
151 Bytes
Binary file not shown.
2.49 KB
Binary file not shown.
1.07 KB
Binary file not shown.

optipfair/pruning/mlp_glu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,15 +629,15 @@ def prune_model_mlp_glu(
629629
# Step 3: Extract accumulated norms
630630
activation_norms = get_activation_norms()
631631

632-
# Verify we collected norms for all layers
633-
num_layers = len(get_model_layers(model))
634-
if len(activation_norms) != num_layers:
632+
# Verify we collected norms for selected layers
633+
expected_layers = len(layer_indices) if layer_indices is not None else len(get_model_layers(model))
634+
if len(activation_norms) != expected_layers:
635635
raise RuntimeError(
636-
f"Calibration failed: expected norms for {num_layers} layers, "
636+
f"Calibration failed: expected norms for {expected_layers} layers, "
637637
f"got {len(activation_norms)}"
638638
)
639639

640-
logger.info(f"Calibration complete: collected activation norms for {num_layers} layers")
640+
logger.info(f"Calibration complete: collected activation norms for {expected_layers} layers")
641641

642642
finally:
643643
# Step 4: Always clean up hooks (even if error occurs)

tests/test_selective_layer_pruning.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_selective_pruning_with_expansion_rate(model):
106106
pruned_model = prune_model(
107107
model=model,
108108
pruning_type="MLP_GLU",
109+
pruning_percentage=None, # Must be None when using expansion_rate
109110
expansion_rate=260, # Target 260% expansion rate
110111
layer_indices=layer_indices,
111112
show_progress=False

0 commit comments

Comments (0)