Status: Open
Labels: bug (Something isn't working)
Description
Summary
When running the finetuning stage of AptaTrans with Lightning `precision="16-mixed"` (AMP), training crashes immediately with:
`RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.`
The crash originates in `pyaptamer/aptatrans/_model_lightning.py`, in the finetuning loss computation (`_step`, called from `training_step`).
Pretraining runs fine for both the apta and prot encoders; the crash only appears during finetuning.
Error / Traceback
(Full log also attached as a file.)
```
Traceback (most recent call last):
0: [rank0]: File "/home/siddharth/work/pyaptamer/train_aptatrans.py", line 535, in <module>
0: [rank0]: main()
0: [rank0]: ~~~~^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/train_aptatrans.py", line 526, in main
0: [rank0]: trainer.fit(lit, train_dataloaders=train_dl, ckpt_path=ckpt_path)
0: [rank0]: ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 584, in fit
0: [rank0]: call._call_and_handle_interrupt(
0: [rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
0: [rank0]: self,
0: [rank0]: ^^^^^
0: [rank0]: ...<6 lines>...
0: [rank0]: weights_only,
0: [rank0]: ^^^^^^^^^^^^^
0: [rank0]: )
0: [rank0]: ^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
0: [rank0]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
0: [rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
0: [rank0]: return function(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 630, in _fit_impl
0: [rank0]: self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
0: [rank0]: ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1079, in _run
0: [rank0]: results = self._run_stage()
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1123, in _run_stage
0: [rank0]: self.fit_loop.run()
0: [rank0]: ~~~~~~~~~~~~~~~~~^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
0: [rank0]: self.advance()
0: [rank0]: ~~~~~~~~~~~~^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 465, in advance
0: [rank0]: self.epoch_loop.run(self._data_fetcher)
0: [rank0]: ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 153, in run
0: [rank0]: self.advance(data_fetcher)
0: [rank0]: ~~~~~~~~~~~~^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 352, in advance
0: [rank0]: batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
0: [rank0]: closure()
0: [rank0]: ~~~~~~~^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
0: [rank0]: self._result = self.closure(*args, **kwargs)
0: [rank0]: ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
0: [rank0]: return func(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
0: [rank0]: step_output = self._step_fn()
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
0: [rank0]: training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 329, in _call_strategy_hook
0: [rank0]: output = fn(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
0: [rank0]: return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
0: [rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 641, in __call__
0: [rank0]: wrapper_output = wrapper_module(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
0: [rank0]: return self._call_impl(*args, **kwargs)
0: [rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
0: [rank0]: return forward_call(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/parallel/distributed.py", line 1666, in forward
0: [rank0]: else self._run_ddp_forward(*inputs, **kwargs)
0: [rank0]: ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/parallel/distributed.py", line 1492, in _run_ddp_forward
0: [rank0]: return self.module(*inputs, **kwargs) # type: ignore[index]
0: [rank0]: ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
0: [rank0]: return self._call_impl(*args, **kwargs)
0: [rank0]: ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
0: [rank0]: return forward_call(*args, **kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 634, in wrapped_forward
0: [rank0]: out = method(*_args, **_kwargs)
0: [rank0]: File "/home/siddharth/work/pyaptamer/pyaptamer/aptatrans/_model_lightning.py", line 122, in training_step
0: [rank0]: loss, accuracy = self._step(batch, batch_idx)
0: [rank0]: ~~~~~~~~~~^^^^^^^^^^^^^^^^^^
0: [rank0]: File "/home/siddharth/work/pyaptamer/pyaptamer/aptatrans/_model_lightning.py", line 97, in _step
0: [rank0]: loss = F.binary_cross_entropy(y_hat, y.float())
0: [rank0]: File "/home/siddharth/work/pyaptamer/.venv/lib/python3.13/site-packages/torch/nn/functional.py", line 3574, in binary_cross_entropy
0: [rank0]: return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
0: [rank0]: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
0: [rank0]: Many models use a sigmoid layer right before the binary cross entropy layer.
0: [rank0]: In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
0: [rank0]: or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are
0: [rank0]: safe to autocast.
```
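Suggested fix (sketch)

The error message itself points at the usual fix: compute the loss on raw logits with `F.binary_cross_entropy_with_logits`, which is AMP-safe because the sigmoid and log are fused and evaluated in full precision internally. Below is a minimal sketch, assuming the model head's final sigmoid can be dropped (or moved to inference only); the variable names are illustrative, not the actual ones in `_model_lightning.py`:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 1)                    # raw model output, no sigmoid
y = torch.randint(0, 2, (8, 1)).float()       # binary targets

# AMP-safe loss on logits (proposed replacement for the line at _step):
loss_logits = F.binary_cross_entropy_with_logits(logits, y)

# Numerically equivalent to the current sigmoid + BCE formulation:
loss_bce = F.binary_cross_entropy(torch.sigmoid(logits), y)
assert torch.allclose(loss_logits, loss_bce, atol=1e-5)

# If the sigmoid must stay inside forward(), an alternative workaround is to
# disable autocast around the loss and upcast to fp32 (at the cost of keeping
# the less numerically stable formulation):
with torch.autocast(device_type="cpu", enabled=False):
    loss = F.binary_cross_entropy(torch.sigmoid(logits).float(), y)
```

If there are downstream consumers that expect probabilities from `forward()`, applying `torch.sigmoid` only in `predict_step`/inference paths would keep the public behavior unchanged while making training AMP-safe.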