Bug description
I'm training a model and inside training_step() I need to call my neural network twice: the first call outputs something that is detached and then fed as input to the second call. Something like the following:
```python
with torch.no_grad():
    x_pred = self.nn(x_input)
x_pred_2 = self.nn(x_input, x_pred)
```
It all works well with the ddp strategy for multi-device training. However, when I enable mixed precision, I get:
```
RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value strategy='ddp_find_unused_parameters_true' or by setting the flag in the strategy with strategy=DDPStrategy(find_unused_parameters=True).
```
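For reference, the mitigation the error message suggests would look roughly like this (a sketch assuming the Lightning 2.x package layout; it silences the check rather than fixing the underlying gradient issue, which is why I'm asking about the right approach):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    precision="16-mixed",
    strategy=DDPStrategy(find_unused_parameters=True),
    # or, equivalently: strategy="ddp_find_unused_parameters_true"
)
```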
I saw that with mixed precision there is some caching happening, and because of it gradients stay disabled after the first torch.no_grad() call. What is the right way to do this? Is this documented in detail somewhere?
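One workaround I'm considering, assuming the culprit is autocast's weight-cast cache being filled inside the no_grad() block: disable that cache for the gradient-free pass. The network, shapes, loss, and the CUDA/fp16 settings below are placeholders, not my actual setup:

```python
import torch
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class TinyNet(torch.nn.Module):
    """Stand-in for the real network; accepts an optional second input."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x, x_prev=None):
        if x_prev is not None:
            x = x + x_prev
        return self.proj(x)


class TwoPassModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.nn = TinyNet()

    def training_step(self, batch, batch_idx):
        x_input, target = batch

        # Disable the autocast weight cache for the gradient-free pass so the
        # half-precision weight copies created here are not reused by the
        # second, gradient-tracking pass (cache_enabled is a standard
        # torch.autocast argument).
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16, cache_enabled=False):
            x_pred = self.nn(x_input)

        # The second pass runs under Lightning's own mixed-precision context
        # and should build the full autograd graph back to the parameters.
        x_pred_2 = self.nn(x_input, x_pred)
        return F.mse_loss(x_pred_2, target)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)
```

An alternative with the same intent would be to call torch.clear_autocast_cache() right after the no_grad() block; I'm not sure which of the two, if either, is the recommended approach.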
One option is to use .detach() on x_pred, but that would actually use more memory than needed, since the first forward pass would then run normally, with gradient tracking enabled.
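For completeness, that variant would just be (same placeholder names as above):

```python
# .detach() variant: the first pass still records the autograd graph and
# saves intermediate activations, so peak memory is higher than with
# torch.no_grad().
x_pred = self.nn(x_input).detach()
x_pred_2 = self.nn(x_input, x_pred)
```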
Thanks
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response