Model Pruning (Lottery Ticket Hypothesis) Not Reinitializing Weights #13040
I am experimenting with model pruning in PyTorch Lightning. I noticed that when pruning with the lottery ticket hypothesis (LTH), weights are not reset to the original initialization as proposed in the LTH paper and mentioned in the Lightning docs. I reproduced the behaviour I saw in my own work using Tutorial 5: Transformers and Multi-Head Attention. Sharing this here to verify my analysis.

Code I added:

```python
class CheckPruningWeight(pl.Callback):
    # Print the first row of every 'input_net.1' tensor at the start of each epoch
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"\nEPOCH {trainer.current_epoch} STARTING")
        for name in trainer.model.state_dict():
            if 'input_net.1' in name:
                print(name, trainer.model.state_dict()[name][:1])
        print('\n')

    # Print the same tensors again at the end of each epoch
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"\nEPOCH {trainer.current_epoch} ENDING")
        for name in trainer.model.state_dict():
            if 'input_net.1' in name:
                print(name, trainer.model.state_dict()[name][:1])
```

Pruning callback:

```python
pruning_callback = pl.callbacks.ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.2,
    use_global_unstructured=True,
    use_lottery_ticket_hypothesis=True,
    verbose=1,
    parameter_names=['weight'],
    resample_parameters=False,
    prune_on_train_epoch_end=False,
)
```

Trainer:

```python
trainer = pl.Trainer(
    default_root_dir=root_dir,
    callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
               pruning_callback,
               CheckPruningWeight()],
    gpus=1 if str(device).startswith("cuda") else 0,
    max_epochs=5,
    gradient_clip_val=5,
    progress_bar_refresh_rate=1,
)
```

Training snippet:
The training snippet shows the original weight tensor changing at the start of each epoch: it takes on the values from the end of the previous epoch. As I understand it, LTH should reset the weights at the start of every epoch to the original initialization (the values at the start of epoch 0). Any thoughts or corrections are much appreciated.
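To make that expectation concrete, here is a minimal sketch of the LTH rewind step in plain PyTorch. It uses `torch.nn.utils.prune` on a single toy `nn.Linear` layer, not the Lightning `ModelPruning` callback or the tutorial model. The layer sizes, the one-step "training", and the names `initial_weight`, `trained_weight`, and `effective` are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy layer standing in for one prunable layer of the real model (assumption).
layer = nn.Linear(4, 3)
initial_weight = layer.weight.detach().clone()  # snapshot before any training

# One SGD step as a stand-in for "train for an epoch".
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(8, 4)
optimizer.zero_grad()
layer(x).sum().backward()
optimizer.step()
trained_weight = layer.weight.detach().clone()

# Prune the 50% smallest-magnitude entries of the *trained* weight.
# This registers `weight_orig` (parameter) and `weight_mask` (buffer).
prune.l1_unstructured(layer, name="weight", amount=0.5)

# LTH rewind: restore the initial values, keep the mask found after training.
with torch.no_grad():
    layer.weight_orig.copy_(initial_weight)

# Effective weight the layer will use on its next forward pass.
effective = layer.weight_orig * layer.weight_mask
mask = layer.weight_mask.bool()
```

After this, the surviving entries of `effective` hold their initial values and the pruned entries are zero, which is the behaviour the question expects at the start of each epoch.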
Replies: 1 comment
While implementing my own custom pruning function and comparing against it, I realized that weight_orig is the updated weight at the end of each epoch, not the originally initialized weight. Printing the actual weight at the start of each epoch gives the original initialized weight. I had a different idea of what weight_orig meant, my bad!
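For anyone with the same confusion, the naming comes from `torch.nn.utils.prune`: once a parameter is pruned, `weight_orig` is the trainable parameter the optimizer keeps updating, `weight_mask` is a buffer, and `weight` is recomputed as their product. The sketch below shows this on a toy `nn.Linear` layer in plain PyTorch. It does not go through the Lightning callback, and the layer sizes and variable names are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

layer = nn.Linear(4, 3)  # toy layer (assumption), not the tutorial model
initial = layer.weight.detach().clone()

prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, `weight` is no longer a parameter; `weight_orig` is.
param_names = {n for n, _ in layer.named_parameters()}
buffer_names = {n for n, _ in layer.named_buffers()}

# Right after pruning, `weight_orig` still equals the unpruned values
# and `weight` is simply `weight_orig * weight_mask`.
orig_matches_initial = torch.equal(layer.weight_orig.detach(), initial)
weight_is_masked_orig = torch.equal(
    layer.weight.detach(), (layer.weight_orig * layer.weight_mask).detach()
)

# One optimizer step: the update lands on `weight_orig`.
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(8, 4)
optimizer.zero_grad()
layer(x).sum().backward()
optimizer.step()

# `weight_orig` has now drifted away from the initialization.
orig_changed_after_step = not torch.equal(layer.weight_orig.detach(), initial)
```

So `weight_orig` means "the unmasked parameter", not "a frozen copy of the initialization", which is why it tracks the end-of-epoch values in the printout.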