Bug description
Hello,
I am using Lightning to train a complex-valued neural network on complex-valued tensors. With single-GPU training there is no issue, but when I train on multiple GPUs with DDP, training diverges. If I train on only one GPU while still passing strategy='ddp' to the trainer, training also diverges.
I have tried to reproduce the issue with the code sample below. The MNIST dataset and the model defined in this sample are simpler than in my current work, so the model does not diverge outright, but it really struggles to converge. To check whether the issue occurs, just comment out the strategy='ddp' line in the trainer.
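For clarity, the toggle is just the strategy argument of the Trainer. A minimal sketch (the full repro below actually passes strategy='ddp_find_unused_parameters_true'):

# converges: default single-device strategy
trainer = L.Trainer(max_epochs=10, devices=1)
# diverges, even on a single GPU: DDP enabled
trainer = L.Trainer(max_epochs=10, devices=1, strategy='ddp')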
This seems to be related to #55375 and #60931.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
from typing import List
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2_transforms
import lightning as L
import torchcvnn.nn as c_nn
from torchmetrics.classification import Accuracy
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
from lightning.pytorch.utilities import rank_zero_only
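
# Complex-valued conv block: two complex convolutions, each followed by complex
# batch norm and a Cardioid activation, then complex average pooling.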
def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    return [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    ]
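
# TensorBoard logger that drops the 'epoch' entry and any key containing
# 'step' or 'val', so each logger only receives epoch-level training metrics.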
class TBLogger(TensorBoardLogger):
    @rank_zero_only
    def log_metrics(self, metrics, step):
        metrics.pop('epoch', None)
        metrics = {k: v for k, v in metrics.items() if ('step' not in k) and ('val' not in k)}
        return super().log_metrics(metrics, step)
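
# Progress bar tweaks: hide the v_num entry and render the bars with ASCII characters.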
class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items

    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        bar = super().init_train_tqdm()
        bar.ascii = ' >'
        return bar

    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.ascii = ' >'
        return bar
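
# Complex-valued MNIST classifier; the final c_nn.Mod() maps the complex
# activations to real magnitudes so they can be fed to CrossEntropyLoss.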
class cMNISTModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.model = self.configure_model()
        self.accuracy = Accuracy(task='multiclass', num_classes=10)
        self.train_step_outputs = {}
        self.valid_step_outputs = {}

    def configure_model(self):
        conv_model = nn.Sequential(
            *conv_block(1, 16, torch.complex64),
            *conv_block(16, 16, torch.complex64),
            *conv_block(16, 32, torch.complex64),
            *conv_block(32, 32, torch.complex64),
            nn.Flatten(),
        )
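        # Infer the flattened feature size of the conv stack with a dummy forward pass.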
        with torch.no_grad():
            conv_model.eval()
            dummy_input = torch.zeros((64, 1, 28, 28), dtype=torch.complex64, requires_grad=False)
            out_conv = conv_model(dummy_input).view(64, -1)
        lin_model = nn.Sequential(
            nn.Linear(out_conv.shape[-1], 124, dtype=torch.complex64),
            c_nn.Cardioid(),
            nn.Linear(124, 10, dtype=torch.complex64),
            c_nn.Mod(),
        )
        return nn.Sequential(conv_model, lin_model)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.parameters(), lr=3e-4)

    def training_step(self, batch, batch_idx):
        data, label = batch
        logits = self(data)
        loss = self.ce_loss(logits, label)
        acc = self.accuracy(logits, label)
        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
        if not self.train_step_outputs:
            self.train_step_outputs = {
                'step_loss': [loss],
                'step_metrics': [acc]
            }
        else:
            self.train_step_outputs['step_loss'].append(loss)
            self.train_step_outputs['step_metrics'].append(acc)
        return loss
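
    # Validation mirrors the training step and accumulates per-step values for the epoch-end hook.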
    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        images, labels = batch
        logits = self(images)
        loss = self.ce_loss(logits, labels)
        acc = self.accuracy(logits, labels)
        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
        if not self.valid_step_outputs:
            self.valid_step_outputs = {
                'step_loss': [loss],
                'step_metrics': [acc]
            }
        else:
            self.valid_step_outputs['step_loss'].append(loss)
            self.valid_step_outputs['step_metrics'].append(acc)
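
    # Epoch-end hooks: average the accumulated step values and log them once per epoch.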
    def on_train_epoch_end(self) -> None:
        _log_dict = {
            'Loss/loss': torch.tensor(self.train_step_outputs['step_loss']).mean(),
            'Metrics/accuracy': torch.tensor(self.train_step_outputs['step_metrics']).mean()
        }
        self.loggers[0].log_metrics(_log_dict, self.current_epoch)
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self) -> None:
        mean_loss_value = torch.tensor(self.valid_step_outputs['step_loss']).mean()
        mean_metrics_value = torch.tensor(self.valid_step_outputs['step_metrics']).mean()
        _log_dict = {
            'Loss/loss': mean_loss_value,
            'Metrics/accuracy': mean_metrics_value
        }
        self.loggers[1].log_metrics(_log_dict, self.current_epoch)
        self.log('val_loss', mean_loss_value, sync_dist=True)
        self.log('val_Accuracy', mean_metrics_value, sync_dist=True)
        self.valid_step_outputs.clear()
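
# Training entry point: MNIST images are moved to the complex domain by applying
# torch.fft.fft in the transform pipeline.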
def train():
    batch_size = 64
    epochs = 10
    torch.set_float32_matmul_precision('high')

    # Dataloading
    train_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )
    valid_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )
    # Train dataloader
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True
    )
    # Valid dataloader
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True
    )
    model = cMNISTModel()
    trainer = L.Trainer(
        max_epochs=epochs,
        strategy='ddp_find_unused_parameters_true',
        num_sanity_val_steps=0,
        benchmark=True,
        enable_checkpointing=True,
        callbacks=[
            CustomProgressBar(),
            EarlyStopping(
                monitor='val_loss',
                verbose=True,
                patience=5,
                min_delta=0.005
            ),
            LearningRateMonitor(logging_interval='epoch'),
            ModelCheckpoint(
                dirpath='weights_storage_/',
                monitor='val_Accuracy',
                verbose=True,
                mode='max'
            )
        ],
        logger=[
            TBLogger('training_logs_', name=None, sub_dir='train'),
            TBLogger('training_logs_', name=None, sub_dir='valid')
        ]
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

if __name__ == "__main__":
    train()

Error messages and logs
No response
Environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.5.1
- Python version: 3.12.7
- OS: Ubuntu 24.04.1 (desktop) or a Slurm cluster
- CUDA/cuDNN version: 12.4
- GPU models and configuration: RTX 4090 (Ubuntu PC), NVIDIA A100 40GB (Slurm)
- How you installed Lightning: pip
More info
@jeremyfix @QuentinGABOT might also be interested in this issue.