
ModelPruning callback logs wrong sparsity #13595

@SungFeng-Huang


🐛 Bug

When using the ModelPruning callback with verbose set to 1 or 2, the logged overall sparsity is wrong whenever a layer has more than one parameter being pruned.
For example, adding a ModelPruning(amount=0.2) callback to the BoringModel Colab logs an overall sparsity of about 0.1 instead of 0.2:

Applied `L1Unstructured`. Pruned: 0/132 (0.00%) -> 13/132 (9.85%)
Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.2. Pruned: 0 (0.00%) -> 13 (20.31%)
Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).bias` with amount=0.2. Pruned: 0 (0.00%) -> 0 (0.00%)

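The problem comes mainly from L346 in the following code. There, a layer's parameters are counted once for every (layer, name) entry in parameters_to_prune, so a layer with several pruned parameters is counted several times (here 2 × 66 = 132 instead of 66):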
https://github.com/Lightning-AI/lightning/blob/b59f80224843886459d54c828325683d770da746/src/pytorch_lightning/callbacks/pruning.py#L343-L353
This can easily be fixed by replacing that line with:

# prev and curr hold (zeros, numel) per pruned parameter, so their numel totals are identical
total_params = sum(params for _, params in prev)
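To illustrate the double counting, here is a minimal plain-Python sketch (no torch needed) that mimics both ways of computing the total. The dict standing in for Linear(32, 2) and the helpers buggy_total and fixed_total are hypothetical; the per-parameter (zeros, numel) stats are taken from the log above, and the buggy version assumes that L346 iterates over all of layer.parameters() once per (layer, name) entry.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for torch.nn.Linear(32, 2): parameter name -> number of elements.
layer_params: Dict[str, int] = {"weight": 32 * 2, "bias": 2}

# Mirrors the (module, name) pairs in parameters_to_prune:
# the same layer appears once per pruned parameter name.
parameters_to_prune: List[Tuple[Dict[str, int], str]] = [
    (layer_params, "weight"),
    (layer_params, "bias"),
]

# Per-parameter (zeros, numel) stats after pruning, as in the log:
# 13 zeros in the weight, 0 in the bias.
curr: List[Tuple[int, int]] = [(13, 64), (0, 2)]


def buggy_total(pairs: List[Tuple[Dict[str, int], str]]) -> int:
    # Sums every parameter of the layer once per entry, so the layer is counted twice.
    return sum(numel for layer, _ in pairs for numel in layer.values())


def fixed_total(stats: List[Tuple[int, int]]) -> int:
    # Sums only the sizes of the parameters that are actually pruned.
    return sum(numel for _, numel in stats)


zeros = sum(z for z, _ in curr)
buggy, fixed = buggy_total(parameters_to_prune), fixed_total(curr)
print(f"buggy: {zeros}/{buggy} ({zeros / buggy:.2%})")  # buggy: 13/132 (9.85%)
print(f"fixed: {zeros}/{fixed} ({zeros / fixed:.2%})")  # fixed: 13/66 (19.70%)
```

The buggy total only matches the fixed one when each layer contributes a single pruned parameter, which is why the bug shows up only for layers with several pruned parameters.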

To Reproduce

https://gist.github.com/SungFeng-Huang/52d676869ad4e8a4a00ac3e29437ecdd

Reproducible python code
import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len



class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

from lightning.pytorch.callbacks import ModelPruning


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2, num_workers=2, persistent_workers=True)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2, num_workers=2, persistent_workers=True)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2, num_workers=2, persistent_workers=True)

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",
        parameters_to_prune=None,
        use_global_unstructured=True,
        amount=0.2,
        apply_pruning=True,
        use_lottery_ticket_hypothesis=True,
        resample_parameters=False,
        verbose=2,
        prune_on_train_epoch_end=True,
    )

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=[pruning],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Expected behavior

The logged overall sparsity should be about 0.2 instead of 0.1:

Applied `L1Unstructured`. Pruned: 0/66 (0.00%) -> 13/66 (19.70%)
Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.2. Pruned: 0 (0.00%) -> 13 (20.31%)
Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).bias` with amount=0.2. Pruned: 0 (0.00%) -> 0 (0.00%)
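As a sanity check, the expected numbers follow from pruning 20% of the 66 prunable elements. This short sketch assumes that global unstructured pruning converts a float amount into a count as round(amount × total elements), as PyTorch's torch.nn.utils.prune does for float amounts; the split of zeros between weight and bias is taken from the log rather than computed.

```python
# Sizes of the two tensors pruned in this report: Linear(32, 2).weight and .bias.
sizes = {"weight": 32 * 2, "bias": 2}
amount = 0.2

total = sum(sizes.values())       # 66 elements are candidates for pruning
n_pruned = round(amount * total)  # assumed rounding rule: round(13.2) -> 13

# Per the log, all 13 zeros land in the weight and the bias stays dense.
weight_zeros, bias_zeros = n_pruned, 0

print(f"overall: {n_pruned}/{total} ({n_pruned / total:.2%})")
print(f"weight:  {weight_zeros}/{sizes['weight']} ({weight_zeros / sizes['weight']:.2%})")
print(f"bias:    {bias_zeros}/{sizes['bias']} ({bias_zeros / sizes['bias']:.2%})")
```

The overall ratio (19.70%) is slightly below the requested 20% only because 13.2 is rounded down to 13 pruned elements.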

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.3
  • Packages:
    • numpy: 1.21.6
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0+cu113
    • pytorch-lightning: 1.6.4
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: #1 SMP Sun Apr 24 10:03:06 PDT 2022

Additional context

cc @tchaton @carmocca
