@@ -205,7 +205,12 @@ def apply_lottery_ticket_hypothesis(self):
         for i, name in names:
             curr, curr_name = self._parameters_to_prune[i]
             assert name == curr_name
-            actual, expected = getattr(curr, name).data, getattr(copy, name).data
+            # Check weight_orig if the parameter is pruned, otherwise check the parameter directly
+            if hasattr(curr, name + "_orig"):
+                actual = getattr(curr, name + "_orig").data
+            else:
+                actual = getattr(curr, name).data
+            expected = getattr(copy, name).data
             allclose = torch.allclose(actual.cpu(), expected)
             assert not allclose if self._resample_parameters else allclose
 
@@ -405,3 +410,56 @@ def __init__(self):
     for module, param_name in parameters_to_prune:
         param = getattr(module, param_name)
         assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
+
+
+def test_lottery_ticket_hypothesis_correctly_reset(tmp_path):
+    """Test that the lottery ticket hypothesis correctly resets unpruned weights to their original values."""
+    seed_everything(42)
+
+    class LTHTestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(32, 2, bias=False)
+            with torch.no_grad():
+                # Initialize with a simple pattern for verification
+                self.layer.weight.copy_(torch.arange(1, 65, dtype=torch.float32).reshape(2, 32))
+
+    model = LTHTestModel()
+    original_weights = model.layer.weight.data.clone()
+
+    # Create a pruning callback that applies both pruning and LTH at epoch 1
+    pruning_callback = ModelPruning(
+        "l1_unstructured",
+        parameters_to_prune=[(model.layer, "weight")],
+        use_lottery_ticket_hypothesis=lambda epoch: epoch == 1,
+        amount=0.5,
+        verbose=0,  # reduce verbosity
+        make_pruning_permanent=False,
+        apply_pruning=lambda epoch: epoch == 1,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
+        limit_train_batches=5,
+        limit_val_batches=1,
+        max_epochs=2,
+        accelerator="cpu",
+        callbacks=pruning_callback,
+    )
+    trainer.fit(model)
+
+    # After training with LTH applied, check that weight_orig was reset correctly
+    assert hasattr(model.layer, "weight_mask"), "Pruning should have created weight_mask"
+    assert hasattr(model.layer, "weight_orig"), "Pruning should have created weight_orig"
+
+    weight_orig = model.layer.weight_orig
+    assert torch.allclose(weight_orig, original_weights, atol=1e-6), (
+        f"Lottery ticket hypothesis failed: weight_orig should be reset to the original values.\n"
+        f"Expected weight_orig: {original_weights}\n"
+        f"Actual weight_orig: {weight_orig}\n"
+        f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
+    )
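
For context on the fix in the first hunk: torch.nn.utils.prune reparametrizes a pruned module by moving the original tensor to <name>_orig, registering a <name>_mask buffer, and recomputing the visible parameter as their elementwise product. A lottery-ticket reset therefore lands in weight_orig, which is what both the updated assertion and the new test inspect. A minimal standalone sketch of that mechanism, separate from this diff and using only the public torch pruning API:

import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(32, 2, bias=False)
original = layer.weight.data.clone()

# l1_unstructured moves the parameter to weight_orig, adds a weight_mask
# buffer, and recomputes weight = weight_orig * weight_mask.
prune.l1_unstructured(layer, name="weight", amount=0.5)

assert hasattr(layer, "weight_orig") and hasattr(layer, "weight_mask")
assert torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)

# layer.weight now has half its entries zeroed (the randomly initialized
# entries are almost surely nonzero), so comparisons against the original
# values must go through weight_orig, as the fixed assertion does.
assert not torch.equal(layer.weight, original)
assert torch.equal(layer.weight_orig, original)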