Skip to content

Weights are misshappen when using model's forward in on_fit_end() hook with FSDP #20255

@QuentinAndre11

Description

@QuentinAndre11

Bug description

Hi everyone !
I am training an image classifier and would like to see the embeddings at the end of training, but I don't find how to do it while using FSDP, since the weights seem to get flattenned outside of train/validation/_step. Indeed, with the following code, I get a RuntimeError: weight should have at least three dimensions.
Is it an attended behaviour? I don't understand because I tried to run in the predict step instead of the on_predict_end hook and I had the same error.

Thanks for your help already

What version are you seeing the problem on?

v2.3

How to reproduce the bug

import certifi
import os
import timm

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import lightning as L

from lightning.pytorch.core import LightningModule, LightningDataModule
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.strategies import FSDPStrategy

from torchvision.datasets import MNIST
from torchvision.transforms import v2
from torchvision.transforms import Lambda

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from einops import rearrange



class ResnetModule(LightningModule):
    def __init__(self, num_classes=10, in_chans=3):
        super().__init__()
        self.model = timm.create_model("resnet50", pretrained = False, num_classes = num_classes, in_chans = in_chans, drop_rate = 0.3)
        self.loss_fn = nn.BCEWithLogitsLoss() 

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        # Here we have self.model.conv1.weight.shape = torch.Size([64, 3, 7, 7])
        loss = self._calculate_loss(batch, batch_idx, "train")
        return loss
    
    def validation_step(self, batch, batch_idx):
        # Here we have self.model.conv1.weight.shape = torch.Size([64, 3, 7, 7])
        self._calculate_loss(batch, batch_idx, "val")
        
    def _calculate_loss(self, batch, batch_idx, mode = "train"):
        images, labels = batch
        
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), 1e-4)
        return optimizer
    
    def get_tb_logger(self, experiment:bool = False) -> pl_loggers.TensorBoardLogger | SummaryWriter:
        for lg in self.trainer.loggers:
            if isinstance(lg, pl_loggers.TensorBoardLogger):
                return lg.experiment if experiment else lg
        return None
    
    def on_train_end(self):
        # !!!!!!!!!
        # Here we have self.model.conv1.weight.shape = torch.Size([9408])
        # !!!!!!!!!
        pass
    
    def on_fit_end(self):
        # !!!!!!!!!
        # Here we have self.model.conv1.weight.shape = torch.Size([9408]) 
        # !!!!!!!!!
        embeddings_activations = []
        embeddings_inputs = []
        embeddings_labels = []
        
        def hook(model, input, output):
            embeddings_activations.append(output)

        # Attach hook to the wanted layer
        hook_handle = self.model.global_pool.register_forward_hook(hook) 
        
        val_dataloader = self.trainer.datamodule.val_dataloader()
        for batch_idx, batch in enumerate(val_dataloader):
            imgs, labels = batch
            embeddings_inputs.append(imgs)
            embeddings_labels.append(labels)
            self(imgs)
        tb_logger = self.get_tb_logger(experiment=True)
        if tb_logger:

            features = rearrange(embeddings_activations, 'n b p -> (n b) p')
            images = rearrange(embeddings_inputs, 'n b c h w -> (n b) c h w')
            labels_one_hot = rearrange(embeddings_labels, 'n b l -> (n b) l')
            metadata = [torch.argmax(t).item() for t in labels_one_hot]
            tb_logger.add_embedding(
                features,
                metadata = metadata,
                label_img = images,
                global_step = self.current_epoch,
                tag = f"{self.model.__class__.__name__}'s embeddings",
            )
            
        hook_handle.remove()
        del embeddings_activations
        del embeddings_inputs
        del embeddings_labels


class MNISTModule(LightningDataModule):
    def __init__(self, num_workers, pin_memory):
        super().__init__()
        self.batch_size = 64
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.transform = transforms.Compose(
            [
            v2.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
            )
        self.mnist_path = os.path.join(os.getcwd(),'dataset','mnist')
        
    def prepare_data(self):
        dataset = MNIST(
            self.mnist_path,
            download=True,
            transform=self.transform if self.transform is not None else transforms.ToTensor(),
            target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
        )
        
        
    def setup(self, stage: str):
        if stage == "fit":
            self.train_set = MNIST(
                self.mnist_path,
                download=False,
                transform=self.transform if self.transform is not None else transforms.ToTensor(),
                target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
            )
            self.val_set = MNIST(
                self.mnist_path,
                download=False,
                transform=self.transform if self.transform is not None else transforms.ToTensor(),
                target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
            )


    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size,num_workers=self.num_workers,pin_memory=self.pin_memory)
    
    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size,num_workers=self.num_workers,pin_memory=self.pin_memory)

def main():
    print("Using PyTorch {} and Lightning {}".format(torch.__version__, L.__version__))

    # Lightning modules    
    datamodule = MNISTModule(num_workers=8, pin_memory=True)
    resnet = ResnetModule(num_classes=10, in_chans=3)
    
    # Logger
    tensorboard = pl_loggers.TensorBoardLogger(
        save_dir = os.path.join(os.getcwd(), 'results'),
        log_graph = False,
        )
    
    # Strategy
    strategy = FSDPStrategy(
            activation_checkpointing_policy={nn.Linear,nn.Conv2d},  
            sharding_strategy="FULL_SHARD",  
        )

    # Trainer
    trainer = Trainer(
        devices=2,
        max_epochs=2,
        strategy=strategy,
        logger=tensorboard,
        )

    trainer.fit(resnet, datamodule)
    trainer.print(torch.cuda.memory_summary())

if __name__ == '__main__':
    main()

Error messages and logs

Traceback (most recent call last):
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 217, in <module>
Traceback (most recent call last):
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    main()
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 193, in main
    return _run_code(code, main_globals, None,
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 86, in _run_code
    trainer.fit(resnet, datamodule)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    exec(code, run_globals)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 217, in <module>
    main()
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 193, in main
    call._call_and_handle_interrupt(
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    trainer.fit(resnet, datamodule)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    call._call_and_handle_interrupt(
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return function(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    return trainer_fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 996, in _run
    self._run(model, ckpt_path=ckpt_path)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 996, in _run
    call._call_lightning_module_hook(self, "on_fit_end")
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 159, in _call_lightning_module_hook
    call._call_lightning_module_hook(self, "on_fit_end")
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 92, in on_fit_end
    output = fn(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 92, in on_fit_end
    self(imgs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    self(imgs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 44, in forward
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 44, in forward
    out = self.model(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    out = self.model(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 635, in forward
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 635, in forward
    x = self.forward_features(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 614, in forward_features
    x = self.forward_features(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 614, in forward_features
    x = self.conv1(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    x = self.conv1(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 164, in forward
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 164, in forward
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    ret = function(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
    ret = function(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._conv_forward(input, self.weight, self.bias)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0): 2.3.3
#- PyTorch Version (e.g., 2.4): 2.1.0 
#- Python version (e.g., 3.12): 3.10.5
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 118 
#- GPU models and configuration: 2 TeslaV100
#- How you installed Lightning(`conda`, `pip`, source): pip

More info

No response

cc @lantiga

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions