
Mixed precision, ddp and torch.no_grad() #20251

@tomasgeffner

Bug description

I'm training a model, and inside training_step() I need to call my neural network twice: the first call outputs something that is detached and then fed as input to the second call. Something like the following:

with torch.no_grad():
    x_pred = self.nn(x_input)
x_pred_2 = self.nn(x_input, x_pred)
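
For context, the surrounding training_step is roughly of this shape (the batch structure and loss function below are just placeholders, not my actual code):

import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    x_input, target = batch  # placeholder batch structure
    # First call: the output is only used as an extra input below, so no gradients are needed.
    with torch.no_grad():
        x_pred = self.nn(x_input)
    # Second call: gradients should flow through this one.
    x_pred_2 = self.nn(x_input, x_pred)
    loss = F.mse_loss(x_pred_2, target)  # placeholder loss
    return loss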

It all works well with the ddp strategy for multi-device training. However, when I enable mixed precision, I get:

RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value strategy='ddp_find_unused_parameters_true' or by setting the flag in the strategy with strategy=DDPStrategy(find_unused_parameters=True).
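
For reference, a minimal sketch of the configuration the error message suggests (the accelerator, devices and precision values here are just assumptions for illustration):

import lightning.pytorch as pl
from lightning.pytorch.strategies import DDPStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,              # assumed multi-device setup
    precision="16-mixed",   # mixed precision
    strategy=DDPStrategy(find_unused_parameters=True),
    # or equivalently: strategy="ddp_find_unused_parameters_true"
)

This only works around the error, though; I'd still like to understand why the parameters end up flagged as unused.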

I saw that with mixed precision there is some caching happening, so the gradients remain off after the first torch.no_grad() call. What is the right way to handle this? Is it documented in detail somewhere?

One option is to call .detach() on x_pred, but that would use more memory than necessary, since the first forward pass would then run with gradients enabled.
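
Concretely, that alternative would be:

# First call with gradients on, then detach; this builds (and then discards) the
# autograd graph for the first forward pass, so it uses extra memory.
x_pred = self.nn(x_input).detach()
x_pred_2 = self.nn(x_input, x_pred)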

Thanks

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @justusschock @lantiga
