Skip to content

Commit 973d0c0

Browse files
authored
fix compile issue (#627)
1 parent efc9e32 commit 973d0c0

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
"""Module for the Physics-Informed Neural Network Interface."""
22

33
from abc import ABCMeta, abstractmethod
4+
import warnings
45
import torch
56

7+
from ...utils import custom_warning_format
68
from ..supervised_solver import SupervisedSolverInterface
79
from ...condition import (
810
InputTargetCondition,
911
InputEquationCondition,
1012
DomainEquationCondition,
1113
)
1214

15+
# set the warning for torch >= 2.8 compile
16+
warnings.formatwarning = custom_warning_format
17+
warnings.filterwarnings("always", category=UserWarning)
18+
1319

1420
class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
1521
"""
@@ -46,6 +52,36 @@ def __init__(self, **kwargs):
4652
# current condition name
4753
self.__metric = None
4854

55+
def setup(self, stage):
56+
"""
57+
Setup method executed at the beginning of training and testing.
58+
59+
This method compiles the model only if the installed torch version
60+
is earlier than 2.8, due to known issues with later versions
61+
(see https://github.com/mathLab/PINA/issues/621).
62+
63+
.. warning::
64+
For torch >= 2.8, compilation is disabled. Forcing compilation
65+
on these versions may cause runtime errors or unstable behavior.
66+
67+
:param str stage: The current stage of the training process
68+
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
69+
:return: The result of the parent class ``setup`` method.
70+
:rtype: Any
71+
"""
72+
# Override the compilation, compiling only for torch < 2.8, see
73+
# related issue at https://github.com/mathLab/PINA/issues/621
74+
if torch.__version__ < "2.8":
75+
self.trainer.compile = True
76+
else:
77+
self.trainer.compile = False
78+
warnings.warn(
79+
"Compilation is disabled for torch >= 2.8. "
80+
"Forcing compilation may cause runtime errors or instability.",
81+
UserWarning,
82+
)
83+
return super().setup(stage)
84+
4985
def optimization_cycle(self, batch, loss_residuals=None):
5086
"""
5187
The optimization cycle for the PINN solver.

pina/solver/solver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ def setup(self, stage):
169169
compile the model if the :class:`~pina.trainer.Trainer`
170170
``compile`` is ``True``.
171171
172-
172+
:param str stage: The current stage of the training process
173+
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
174+
:return: The result of the parent class ``setup`` method.
175+
:rtype: Any
173176
"""
174177
if stage == "fit" and self.trainer.compile:
175178
self._setup_compile()

0 commit comments

Comments
 (0)