
[BUG] Finetuning crashes with AMP: binary_cross_entropy is unsafe to autocast (needs BCEWithLogitsLoss or disable autocast for loss) #230

@siddharth7113

Description

Summary

When running the finetuning stage of AptaTrans with Lightning precision=16-mixed (AMP), training crashes immediately with:

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.

This happens inside pyaptamer/aptatrans/_model_lightning.py during the finetuning loss computation.

Pretraining runs fine for both the apta and prot encoders; the crash appears only during finetuning.
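For context, a minimal sketch of the change that would avoid the crash, assuming the model head can be made to return raw logits (no final sigmoid). `step_loss` is a hypothetical helper for illustration, not a function from the codebase:

```python
import torch
import torch.nn.functional as F

# Hypothetical helper, not from pyaptamer: the *_with_logits variant fuses
# sigmoid + BCE and handles the numerics in float32 internally, so it is
# safe under precision="16-mixed".
def step_loss(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return F.binary_cross_entropy_with_logits(logits, y.float())

# Numerically it matches sigmoid followed by plain BCE:
logits = torch.randn(8)
y = torch.randint(0, 2, (8,)).float()
fused = step_loss(logits, y)
unfused = F.binary_cross_entropy(torch.sigmoid(logits), y)
```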

Error / Traceback

(Attaching the full log as a file as well.)

Traceback (most recent call last):
0: [rank0]:   File "/home/siddharth/work/pyaptamer/train_aptatrans.py", line 535, in <module>
0: [rank0]:     main()
0: [rank0]:     ~~~~^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/train_aptatrans.py", line 526, in main
0: [rank0]:     trainer.fit(lit, train_dataloaders=train_dl, ckpt_path=ckpt_path)
0: [rank0]:     ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 584, in fit
0: [rank0]:     call._call_and_handle_interrupt(
0: [rank0]:     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
0: [rank0]:         self,
0: [rank0]:         ^^^^^
0: [rank0]:     ...<6 lines>...
0: [rank0]:         weights_only,
0: [rank0]:         ^^^^^^^^^^^^^
0: [rank0]:     )
0: [rank0]:     ^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
0: [rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
0: [rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
0: [rank0]:     return function(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 630, in _fit_impl
0: [rank0]:     self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
0: [rank0]:     ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1079, in _run
0: [rank0]:     results = self._run_stage()
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1123, in _run_stage
0: [rank0]:     self.fit_loop.run()
0: [rank0]:     ~~~~~~~~~~~~~~~~~^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
0: [rank0]:     self.advance()
0: [rank0]:     ~~~~~~~~~~~~^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 465, in advance
0: [rank0]:     self.epoch_loop.run(self._data_fetcher)
0: [rank0]:     ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 153, in run
0: [rank0]:     self.advance(data_fetcher)
0: [rank0]:     ~~~~~~~~~~~~^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 352, in advance
0: [rank0]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
0: [rank0]:     closure()
0: [rank0]:     ~~~~~~~^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
0: [rank0]:     self._result = self.closure(*args, **kwargs)
0: [rank0]:                    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
0: [rank0]:     return func(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
0: [rank0]:     step_output = self._step_fn()
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
0: [rank0]:     training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 329, in _call_strategy_hook
0: [rank0]:     output = fn(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
0: [rank0]:     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
0: [rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 641, in __call__
0: [rank0]:     wrapper_output = wrapper_module(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
0: [rank0]:     return self._call_impl(*args, **kwargs)
0: [rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
0: [rank0]:     return forward_call(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/parallel/distributed.py", line 1666, in forward
0: [rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
0: [rank0]:          ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/parallel/distributed.py", line 1492, in _run_ddp_forward
0: [rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
0: [rank0]:            ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
0: [rank0]:     return self._call_impl(*args, **kwargs)
0: [rank0]:            ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
0: [rank0]:     return forward_call(*args, **kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 634, in wrapped_forward
0: [rank0]:     out = method(*_args, **_kwargs)
0: [rank0]:   File "/home/siddharth/work/pyaptamer/pyaptamer/aptatrans/_model_lightning.py", line 122, in training_step
0: [rank0]:     loss, accuracy = self._step(batch, batch_idx)
0: [rank0]:                      ~~~~~~~~~~^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/home/siddharth/work/pyaptamer/pyaptamer/aptatrans/_model_lightning.py", line 97, in _step
0: [rank0]:     loss = F.binary_cross_entropy(y_hat, y.float())
0: [rank0]:   File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/functional.py", line 3574, in binary_cross_entropy
0: [rank0]:     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
0: [rank0]:            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
0: [rank0]: Many models use a sigmoid layer right before the binary cross entropy layer.
0: [rank0]: In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
0: [rank0]: or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are
0: [rank0]: safe to autocast.
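If keeping the model's final sigmoid is preferred, the error message's other escape hatch is to compute the loss with autocast disabled. A sketch of that alternative; `bce_outside_autocast` is a hypothetical helper, not from the codebase:

```python
import torch
import torch.nn.functional as F

# Hypothetical workaround: leave the forward pass under autocast, but drop
# to float32 just for the BCE call, where the unfused op is safe.
def bce_outside_autocast(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    with torch.autocast(device_type=y_hat.device.type, enabled=False):
        return F.binary_cross_entropy(y_hat.float(), y.float())

# y_hat must already be probabilities in [0, 1] (e.g. sigmoid output).
probs = torch.sigmoid(torch.randn(4))
targets = torch.ones(4)
loss = bce_outside_autocast(probs, targets)
```

The logits-loss fix above is generally preferable, since it also improves numerical stability; this variant only avoids the autocast ban.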

Labels: bug