Fabric does not sync gradients? #20293

@RuABraun

Bug description

I'm reproducing a setup I have in PyTorch Lightning (PL) in Fabric. I've been struggling for a while to figure out why the loss plateaus when training with Fabric, even though the setup is identical to the PL run in terms of data.

When doing single-GPU training I get matching loss values. But when I switch to 8 GPUs, the PL training has half the loss after 200 steps and keeps trending down, while the Fabric training gets stuck. I'm using 32-true precision in both, I'm not using gradient accumulation, and the Fabric code does not use no_backward_sync().

I decided to check the gradients on each rank. For the PL training I added:

# inside the LightningModule
def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    if self.global_step % 100 == 0:
        # print the summed absolute gradient of a trainable student parameter on every rank
        if self.student_ctc.weight.grad is not None:
            print(rank_zero_only.rank, self.student_ctc.weight.grad.abs().sum(), flush=True)

This shows identical gradients on each rank, as expected.

In Fabric I added a line after optimizer.step():

print(fabric.global_rank, model.out.bias.grad.abs().sum(), flush=True)

This shows different gradients on different ranks.
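
A more direct way to compare the values across ranks than eyeballing interleaved prints would be something like this (just a sketch; it uses fabric.all_gather on a 1-element tensor and reuses model/fabric from above):

# Sketch: gather each rank's gradient sum so rank 0 can print them side by side.
grad_sum = model.out.bias.grad.abs().sum().view(1)
gathered = fabric.all_gather(grad_sum)  # shape: (world_size, 1)
if fabric.global_rank == 0:
    print("per-rank grad sums:", gathered.flatten().tolist(), flush=True)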

I'm wondering whether this may be because of what I'm doing: I'm distilling a model, so I have a teacher model with frozen weights and a student model that I want to train. I use a wrapper nn.ModuleDict to hold the teacher and student modules in a nested fashion, call fabric.setup_module(models) on it, and pass only the student parameters to the optimizer, which then gets passed to fabric.setup_optimizers(optimizer).
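
Concretely, the wrapping looks roughly like this (names are simplified and the Linear layers stand in for my real modules, this is not my exact code):

from torch import nn

# Sketch of the wrapper described above; the Linear layers are placeholders.
teacher_model = nn.Linear(16, 16)
student_model = nn.Linear(16, 16)
models = nn.ModuleDict({"teacher": teacher_model, "student": student_model})

# The teacher is frozen; only the student should be trained.
for p in models.teacher.parameters():
    p.requires_grad = False

# Only the student's parameters are handed to the optimizer later on.
student_params = list(models.student.parameters())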

Could it be that Fabric assumes one has a single nn.Module whose forward() gets called? I know forward() gets overridden by the wrapper, and I wonder whether hooks are also added that matter for gradient syncing, so that this forward() needs to be called for DDP to work.
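
If that's the cause, I guess the fix would be to give the wrapper a single forward() and always call the object returned by fabric.setup_module(), so that DDP's forward actually runs. A rough sketch of what I mean (DistillationWrapper and its forward signature are made up for illustration, and it reuses fabric/teacher_model/student_model from above):

import torch
from torch import nn

# Hypothetical wrapper with one forward(), so the whole computation goes
# through the module that Fabric/DDP actually wrapped.
class DistillationWrapper(nn.Module):
    def __init__(self, teacher: nn.Module, student: nn.Module):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def forward(self, x):
        with torch.no_grad():
            teacher_out = self.teacher(x)  # frozen teacher
        student_out = self.student(x)      # trainable student
        return teacher_out, student_out

model = fabric.setup_module(DistillationWrapper(teacher_model, student_model))
teacher_out, student_out = model(batch)  # call the wrapped forward(), not the submodules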

What version are you seeing the problem on?

v2.2

How to reproduce the bug

The code looks roughly like this:

fabric = L.Fabric(
    accelerator="gpu",
    num_nodes=args.num_nodes,
    devices=args.num_devices,
    strategy=strategy,
    precision=config.general.precision,
)

fabric.launch()
[..]
model = fabric.setup_module(models)  # models is the nn.ModuleDict holding teacher and student
optimizer = create_optimizer(params, models.student)  # optimizer only sees the student parameters
optimizer = fabric.setup_optimizers(optimizer)
[..]
# forward passes call the submodules directly, not the wrapped module's forward()
model.teacher.module(input)
model.teacher.module2(input)
model.student.module(input)
model.student.module2(input)

and it is launched with srun python train.py.

I can't share the exact code.
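
A self-contained sketch in the same spirit (random data and tiny Linear modules standing in for my actual models, so not my real code) would look something like this:

import torch
import lightning as L
from torch import nn


def main():
    fabric = L.Fabric(accelerator="gpu", devices=8, strategy="ddp", precision="32-true")
    fabric.launch()
    torch.manual_seed(fabric.global_rank)  # make sure each rank sees different data

    teacher = nn.Linear(16, 16)
    student = nn.Linear(16, 16)
    for p in teacher.parameters():
        p.requires_grad = False  # frozen teacher

    models = nn.ModuleDict({"teacher": teacher, "student": student})
    model = fabric.setup_module(models)
    optimizer = torch.optim.SGD(model.student.parameters(), lr=0.1)
    optimizer = fabric.setup_optimizers(optimizer)

    for step in range(10):
        x = torch.randn(4, 16, device=fabric.device)
        with torch.no_grad():
            target = model.teacher(x)  # distillation target from the frozen teacher
        out = model.student(x)         # submodule called directly, as in my setup
        loss = nn.functional.mse_loss(out, target)
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()
        # if gradients were synced, this value should match across ranks
        print(fabric.global_rank, step, model.student.bias.grad.abs().sum().item(), flush=True)


if __name__ == "__main__":
    main()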

### Environment

  • Lightning:
    • lightning: 2.2.3
    • lightning-cloud: 0.5.42
    • lightning-utilities: 0.11.2
    • pytorch-lightning: 2.2.3
    • pytorch-wpe: 0.0.1
    • torch: 2.3.0+cu121
    • torch-complex: 0.4.3
    • torchaudio: 2.3.0+cu121
    • torchdata: 0.7.1
    • torchmetrics: 1.4.0.post0
    • torchvision: 0.18.0+cu121

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.2.x
