Skip to content

Model diverges or struggles to converge with complex-valued tensors in DDP #20480

@ouioui199

Description

@ouioui199

Bug description

Hello,

I am using Lightning to train a complex-valued neural network on complex-valued tensors. Single-GPU training works without issue, but when I train on multiple GPUs with DDP, training diverges. If I train on only one GPU while still declaring `strategy='ddp'` in the trainer, training also diverges.

I've tried to reproduce the issue with the code sample below. The MNIST dataset and the model defined in this sample are simpler than those in my current work, so the model does not diverge but really struggles to converge. To check whether the issue occurs, comment out the DDP `strategy=...` line in the trainer.

This seems to be related to #55375 and #60931

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from typing import List

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2_transforms
import lightning as L
import torchcvnn.nn as c_nn
from torchmetrics.classification import Accuracy
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
from lightning.pytorch.utilities import rank_zero_only


def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    """Build the layers of one convolutional stage.

    The stage is two 3x3 convolutions (stride 1, padding 1), each followed by
    ``c_nn.BatchNorm2d`` and ``c_nn.Cardioid``, and is closed by a 2x2
    ``c_nn.AvgPool2d`` with stride 2.

    Args:
        in_c: Number of input channels of the first convolution.
        out_c: Number of output channels of both convolutions.
        cdtype: dtype handed to the ``nn.Conv2d`` layers.

    Returns:
        The layers as a flat list, in forward order.
    """
    layers: List[nn.Module] = []
    # Modules are instantiated in forward order: one conv/norm/activation
    # triple per input width, then the pooling layer.
    for channels_in in (in_c, out_c):
        layers.append(
            nn.Conv2d(channels_in, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype)
        )
        layers.append(c_nn.BatchNorm2d(out_c))
        layers.append(c_nn.Cardioid())
    layers.append(c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
    return layers


class TBLogger(TensorBoardLogger):
    """TensorBoard logger that filters out bookkeeping and per-step entries."""

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Forward a filtered copy of ``metrics`` to ``TensorBoardLogger``.

        Entries are dropped when the key is exactly ``'epoch'`` or when the
        key contains the substring ``'step'`` or ``'val'``.

        Args:
            metrics: Mapping of metric name to value. It is not modified.
            step: Step index passed through to the parent implementation.

        Returns:
            Whatever the parent ``log_metrics`` returns.
        """
        # Build a new dict instead of popping from ``metrics``: the caller may
        # hand the very same mapping to several loggers in turn.
        kept = {
            name: value
            for name, value in metrics.items()
            if name != 'epoch' and 'step' not in name and 'val' not in name
        }
        return super().log_metrics(kept, step)
    
    
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides ``v_num`` and sets ``ascii`` to ``' >'``."""

    def get_metrics(self, trainer, model):
        """Return the parent's progress-bar metrics with ``"v_num"`` removed."""
        shown = super().get_metrics(trainer, model)
        shown.pop("v_num", None)
        return shown

    @staticmethod
    def _apply_ascii(bar):
        """Set the ``ascii`` attribute of ``bar`` and hand the bar back."""
        bar.ascii = ' >'
        return bar

    def init_train_tqdm(self) -> Tqdm:
        """Create the training bar via the parent, then set its ``ascii``."""
        return self._apply_ascii(super().init_train_tqdm())

    def init_validation_tqdm(self):
        """Create the validation bar via the parent, then set its ``ascii``."""
        return self._apply_ascii(super().init_validation_tqdm())


class cMNISTModel(L.LightningModule):
    """Complex-valued convolutional classifier with 10 output classes.

    The network is four ``conv_block`` stages followed by two complex linear
    layers and a ``c_nn.Mod`` layer. Per-step loss and accuracy are collected
    in ``train_step_outputs`` / ``valid_step_outputs`` and averaged at the end
    of each epoch.
    """

    def __init__(self):
        super().__init__()

        self.ce_loss = nn.CrossEntropyLoss()
        self.model = self.configure_model()
        self.accuracy = Accuracy(task='multiclass', num_classes=10)

        # Per-epoch buffers: {'step_loss': [...], 'step_metrics': [...]}.
        self.train_step_outputs = {}
        self.valid_step_outputs = {}

    def configure_model(self):
        """Build the network, or return it if it has already been built.

        ``configure_model`` is also the name of a ``LightningModule`` hook, so
        the trainer may call this method again after ``__init__``. Returning
        the existing ``self.model`` keeps the method idempotent and avoids
        constructing a second network that would be thrown away.

        Returns:
            ``nn.Sequential`` of the convolutional part and the linear head.
        """
        existing = getattr(self, 'model', None)
        if existing is not None:
            return existing

        conv_model = nn.Sequential(
            *conv_block(1, 16, torch.complex64),
            *conv_block(16, 16, torch.complex64),
            *conv_block(16, 32, torch.complex64),
            *conv_block(32, 32, torch.complex64),
            nn.Flatten(),
        )

        # Probe the flattened feature size with a dummy forward pass in eval
        # mode, then restore the previous mode so the probe does not leave the
        # convolutional part (and its normalisation layers) stuck in eval().
        # NOTE(review): recent Lightning releases reportedly preserve the
        # per-submodule train/eval mode during fit - confirm against changelog.
        was_training = conv_model.training
        conv_model.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(
                    (64, 1, 28, 28), dtype=torch.complex64, requires_grad=False
                )
                n_features = conv_model(dummy_input).shape[-1]
        finally:
            conv_model.train(was_training)

        lin_model = nn.Sequential(
            nn.Linear(n_features, 124, dtype=torch.complex64),
            c_nn.Cardioid(),
            nn.Linear(124, 10, dtype=torch.complex64),
            c_nn.Mod(),
        )

        return nn.Sequential(conv_model, lin_model)

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``lr=3e-4``."""
        return torch.optim.Adam(params=self.parameters(), lr=3e-4)

    @staticmethod
    def _record_step(store, loss, acc):
        """Append detached copies of ``loss`` and ``acc`` to ``store``."""
        # detach(): keeping the raw loss would keep every step's autograd
        # graph alive until the end of the epoch.
        store.setdefault('step_loss', []).append(loss.detach())
        store.setdefault('step_metrics', []).append(acc.detach())

    @staticmethod
    def _epoch_means(store):
        """Return ``(mean loss, mean accuracy)`` over the values in ``store``."""
        return (
            torch.stack(store['step_loss']).mean(),
            torch.stack(store['step_metrics']).mean(),
        )

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for one training batch.

        Args:
            batch: Pair ``(data, label)``.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss.
        """
        data, label = batch
        logits = self(data)

        loss = self.ce_loss(logits, label)
        acc = self.accuracy(logits, label)

        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)

        self._record_step(self.train_step_outputs, loss, acc)

        return loss

    def validation_step(self, batch, batch_idx: int):
        """Compute, log and record loss and accuracy for one validation batch.

        Args:
            batch: Pair ``(images, labels)``.
            batch_idx: Index of the batch (unused).
        """
        images, labels = batch
        logits = self(images)

        loss = self.ce_loss(logits, labels)
        acc = self.accuracy(logits, labels)
        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)

        self._record_step(self.valid_step_outputs, loss, acc)

    def on_train_epoch_end(self) -> None:
        """Write the epoch-mean training loss/accuracy to the first logger."""
        if not self.train_step_outputs:
            # No training step was recorded this epoch; nothing to average.
            return

        mean_loss_value, mean_metrics_value = self._epoch_means(self.train_step_outputs)
        _log_dict = {
            'Loss/loss': mean_loss_value,
            'Metrics/accuracy': mean_metrics_value
        }

        self.loggers[0].log_metrics(_log_dict, self.current_epoch)
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self) -> None:
        """Write the epoch-mean validation loss/accuracy and log monitor keys.

        The means go to the second logger, and are also logged as
        ``'val_loss'`` and ``'val_Accuracy'``.
        """
        if not self.valid_step_outputs:
            # No validation step was recorded this epoch; nothing to average.
            return

        mean_loss_value, mean_metrics_value = self._epoch_means(self.valid_step_outputs)

        _log_dict = {
            'Loss/loss': mean_loss_value,
            'Metrics/accuracy': mean_metrics_value
        }

        # Index 1 assumes at least two loggers are attached to the trainer.
        self.loggers[1].log_metrics(_log_dict, self.current_epoch)

        self.log('val_loss', mean_loss_value, sync_dist=True)
        self.log('val_Accuracy', mean_metrics_value, sync_dist=True)
        self.valid_step_outputs.clear()


def train():
    """Train ``cMNISTModel`` on MNIST with the DDP strategy.

    Builds the train/validation datasets (each image is converted with
    ``PILToTensor`` and then ``torch.fft.fft``), wraps them in data loaders,
    configures a ``Trainer`` with progress bar, early stopping, learning-rate
    monitoring, checkpointing and two TensorBoard loggers, and calls ``fit``.
    """
    batch_size = 64
    epochs = 10
    torch.set_float32_matmul_precision('high')

    def load_split(is_train):
        """Return the MNIST split selected by ``is_train``."""
        return torchvision.datasets.MNIST(
            root="./data",
            train=is_train,
            download=True,
            transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
        )

    def make_loader(dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` with the shared settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=4,
            persistent_workers=True,
            pin_memory=True,
        )

    # Datasets first, then loaders: only the training loader shuffles.
    train_dataset = load_split(True)
    valid_dataset = load_split(False)
    train_loader = make_loader(train_dataset, True)
    valid_loader = make_loader(valid_dataset, False)

    model = cMNISTModel()

    early_stopping = EarlyStopping(
        monitor='val_loss',
        verbose=True,
        patience=5,
        min_delta=0.005,
    )
    checkpointing = ModelCheckpoint(
        dirpath='weights_storage_/',
        monitor='val_Accuracy',
        verbose=True,
        mode='max',
    )
    callbacks = [
        CustomProgressBar(),
        early_stopping,
        LearningRateMonitor(logging_interval='epoch'),
        checkpointing,
    ]
    # One logger per sub-directory; the model addresses them by position.
    loggers = [
        TBLogger('training_logs_', name=None, sub_dir=sub_dir)
        for sub_dir in ('train', 'valid')
    ]

    trainer = L.Trainer(
        max_epochs=epochs,
        strategy='ddp_find_unused_parameters_true',
        num_sanity_val_steps=0,
        benchmark=True,
        enable_checkpointing=True,
        callbacks=callbacks,
        logger=loggers,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
    

# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train()

Error messages and logs

No response

Environment

Current environment
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.5.1
#- Python version: 3.12.7
#- OS: Linux Ubuntu 24.04.1 or Slurm
#- CUDA/cuDNN version: 12.4
#- GPU models and configuration: RTX 4090 (Ubuntu pc), NVIDIA A100 40G (Slurm)
#- How you installed Lightning: pip

More info

@jeremyfix @QuentinGABOT might also be interested in this issue

Metadata

Metadata

Assignees

No one assigned

    Labels

    3rd party (Related to a 3rd-party) · bug (Something isn't working) · ver: 2.4.x

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions