Labels: bug (Something isn't working), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.3.x
Description
Bug description
Hi everyone!
I am training an image classifier and would like to log the embeddings at the end of training, but I cannot find a way to do this with FSDP, because the weights appear to stay flattened outside of training_step/validation_step. With the code below, running a forward pass in the on_fit_end hook raises RuntimeError: weight should have at least three dimensions.
Is this the intended behaviour? I don't understand it, because I also tried running the forward pass in predict_step instead of the on_predict_end hook and got the same error.
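A possible workaround, shown here as a minimal, untested sketch: temporarily gather the full (unflattened) parameters around the extra forward pass with torch's FSDP.summon_full_params. It assumes that, under FSDPStrategy, the FSDP-wrapped root module is reachable via self.trainer.model, and that this method could replace on_fit_end in the repro below.

    # Untested sketch of a possible workaround, meant to replace on_fit_end in the
    # repro below. Assumption: under FSDPStrategy, self.trainer.model is the
    # FSDP-wrapped root module, so summon_full_params can temporarily materialize
    # the full parameters.
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def on_fit_end(self):
        fsdp_root = self.trainer.model  # assumed to be the FSDP-wrapped module
        with FSDP.summon_full_params(fsdp_root), torch.no_grad():
            # Inside this context the parameters are gathered and unflattened,
            # so the extra forward pass (and any forward hooks) should see the
            # conv weights with their original shapes.
            for imgs, labels in self.trainer.datamodule.val_dataloader():
                self(imgs.to(self.device))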
Thanks in advance for your help!
What version are you seeing the problem on?
v2.3
How to reproduce the bug
import certifi
import os
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import lightning as L
from lightning.pytorch.core import LightningModule, LightningDataModule
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.strategies import FSDPStrategy
from torchvision.datasets import MNIST
from torchvision.transforms import v2
from torchvision.transforms import Lambda
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from einops import rearrange


class ResnetModule(LightningModule):
    def __init__(self, num_classes=10, in_chans=3):
        super().__init__()
        self.model = timm.create_model("resnet50", pretrained=False, num_classes=num_classes, in_chans=in_chans, drop_rate=0.3)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        # Here we have self.model.conv1.weight.shape = torch.Size([64, 3, 7, 7])
        loss = self._calculate_loss(batch, batch_idx, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        # Here we have self.model.conv1.weight.shape = torch.Size([64, 3, 7, 7])
        self._calculate_loss(batch, batch_idx, "val")

    def _calculate_loss(self, batch, batch_idx, mode="train"):
        images, labels = batch
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), 1e-4)
        return optimizer

    def get_tb_logger(self, experiment: bool = False) -> pl_loggers.TensorBoardLogger | SummaryWriter:
        for lg in self.trainer.loggers:
            if isinstance(lg, pl_loggers.TensorBoardLogger):
                return lg.experiment if experiment else lg
        return None

    def on_train_end(self):
        # !!!!!!!!!
        # Here we have self.model.conv1.weight.shape = torch.Size([9408])
        # !!!!!!!!!
        pass

    def on_fit_end(self):
        # !!!!!!!!!
        # Here we have self.model.conv1.weight.shape = torch.Size([9408])
        # !!!!!!!!!
        embeddings_activations = []
        embeddings_inputs = []
        embeddings_labels = []

        def hook(model, input, output):
            embeddings_activations.append(output)

        # Attach hook to the wanted layer
        hook_handle = self.model.global_pool.register_forward_hook(hook)
        val_dataloader = self.trainer.datamodule.val_dataloader()
        for batch_idx, batch in enumerate(val_dataloader):
            imgs, labels = batch
            embeddings_inputs.append(imgs)
            embeddings_labels.append(labels)
            self(imgs)
        tb_logger = self.get_tb_logger(experiment=True)
        if tb_logger:
            features = rearrange(embeddings_activations, 'n b p -> (n b) p')
            images = rearrange(embeddings_inputs, 'n b c h w -> (n b) c h w')
            labels_one_hot = rearrange(embeddings_labels, 'n b l -> (n b) l')
            metadata = [torch.argmax(t).item() for t in labels_one_hot]
            tb_logger.add_embedding(
                features,
                metadata=metadata,
                label_img=images,
                global_step=self.current_epoch,
                tag=f"{self.model.__class__.__name__}'s embeddings",
            )
        hook_handle.remove()
        del embeddings_activations
        del embeddings_inputs
        del embeddings_labels


class MNISTModule(LightningDataModule):
    def __init__(self, num_workers, pin_memory):
        super().__init__()
        self.batch_size = 64
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.transform = transforms.Compose(
            [
                v2.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        self.mnist_path = os.path.join(os.getcwd(), 'dataset', 'mnist')

    def prepare_data(self):
        dataset = MNIST(
            self.mnist_path,
            download=True,
            transform=self.transform if self.transform is not None else transforms.ToTensor(),
            target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)),
        )

    def setup(self, stage: str):
        if stage == "fit":
            self.train_set = MNIST(
                self.mnist_path,
                download=False,
                transform=self.transform if self.transform is not None else transforms.ToTensor(),
                target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)),
            )
            self.val_set = MNIST(
                self.mnist_path,
                download=False,
                transform=self.transform if self.transform is not None else transforms.ToTensor(),
                target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)),
            )

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory)


def main():
    print("Using PyTorch {} and Lightning {}".format(torch.__version__, L.__version__))
    # Lightning modules
    datamodule = MNISTModule(num_workers=8, pin_memory=True)
    resnet = ResnetModule(num_classes=10, in_chans=3)
    # Logger
    tensorboard = pl_loggers.TensorBoardLogger(
        save_dir=os.path.join(os.getcwd(), 'results'),
        log_graph=False,
    )
    # Strategy
    strategy = FSDPStrategy(
        activation_checkpointing_policy={nn.Linear, nn.Conv2d},
        sharding_strategy="FULL_SHARD",
    )
    # Trainer
    trainer = Trainer(
        devices=2,
        max_epochs=2,
        strategy=strategy,
        logger=tensorboard,
    )
    trainer.fit(resnet, datamodule)
    trainer.print(torch.cuda.memory_summary())


if __name__ == '__main__':
    main()
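Since the parameters only seem to be unflattened while an FSDP forward pass is running, another possible direction (again a minimal, untested sketch; the _emb_* buffers and the hook handle are illustrative names, the rest reuses the ResnetModule methods above) would be to collect the embeddings during the normal validation loop instead of after fit:

    # Untested sketch: collect the embeddings during the regular validation loop,
    # where the FSDP forward gathers the full parameters automatically. These
    # methods would replace/extend the corresponding ones in ResnetModule above.

    def on_validation_epoch_start(self):
        self._emb_feats, self._emb_imgs, self._emb_labels = [], [], []
        self._emb_hook = self.model.global_pool.register_forward_hook(
            lambda module, inp, out: self._emb_feats.append(out.detach().cpu())
        )

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        self._emb_imgs.append(images.cpu())
        self._emb_labels.append(labels.cpu())
        self._calculate_loss(batch, batch_idx, "val")

    def on_validation_epoch_end(self):
        self._emb_hook.remove()
        if self.trainer.sanity_checking:
            return
        tb_logger = self.get_tb_logger(experiment=True)
        if tb_logger:
            features = torch.cat(self._emb_feats)
            images = torch.cat(self._emb_imgs)
            metadata = [torch.argmax(t).item() for t in torch.cat(self._emb_labels)]
            tb_logger.add_embedding(
                features,
                metadata=metadata,
                label_img=images,
                global_step=self.current_epoch,
                tag=f"{self.model.__class__.__name__}'s embeddings",
            )

Note that with devices=2 each rank only sees its own shard of the validation set, so embeddings logged this way would cover only part of the data per rank unless they are gathered across processes first.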
Error messages and logs
Both ranks fail with the same error:

Traceback (most recent call last):
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/shared/nfs/apps/python/gcc/12.1.0/python-3.10.5/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 217, in <module>
    main()
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 193, in main
    trainer.fit(resnet, datamodule)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 996, in _run
    call._call_lightning_module_hook(self, "on_fit_end")
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 92, in on_fit_end
    self(imgs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/my_project/minimal_working_example.py", line 44, in forward
    out = self.model(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 635, in forward
    x = self.forward_features(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/timm/models/resnet.py", line 614, in forward_features
    x = self.conv1(x)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 164, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
    ret = function(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/shared/nfs/home/user/venv/torch21/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions
Environment
- PyTorch Lightning Version: 2.3.3
- PyTorch Version: 2.1.0
- Python version: 3.10.5
- OS: Linux
- CUDA/cuDNN version: 11.8
- GPU models and configuration: 2x Tesla V100
- How you installed Lightning: pip
More info
No response
cc @lantiga