|
1 | 1 | """Module for the Physics-Informed Neural Network Interface.""" |
2 | 2 |
|
3 | 3 | from abc import ABCMeta, abstractmethod |
| 4 | +import warnings |
4 | 5 | import torch |
5 | 6 |
|
| 7 | +from ...utils import custom_warning_format |
6 | 8 | from ..supervised_solver import SupervisedSolverInterface |
7 | 9 | from ...condition import ( |
8 | 10 | InputTargetCondition, |
9 | 11 | InputEquationCondition, |
10 | 12 | DomainEquationCondition, |
11 | 13 | ) |
12 | 14 |
|
| 15 | +# set the warning for torch >= 2.8 compile |
| 16 | +warnings.formatwarning = custom_warning_format |
| 17 | +warnings.filterwarnings("always", category=UserWarning) |
| 18 | + |
13 | 19 |
|
14 | 20 | class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta): |
15 | 21 | """ |
@@ -46,6 +52,36 @@ def __init__(self, **kwargs): |
46 | 52 | # current condition name |
47 | 53 | self.__metric = None |
48 | 54 |
|
| 55 | + def setup(self, stage): |
| 56 | + """ |
| 57 | + Setup method executed at the beginning of training and testing. |
| 58 | +
|
| 59 | + This method compiles the model only if the installed torch version |
| 60 | + is earlier than 2.8, due to known issues with later versions |
| 61 | + (see https://github.com/mathLab/PINA/issues/621). |
| 62 | +
|
| 63 | + .. warning:: |
| 64 | + For torch >= 2.8, compilation is disabled. Forcing compilation |
| 65 | + on these versions may cause runtime errors or unstable behavior. |
| 66 | +
|
| 67 | + :param str stage: The current stage of the training process |
| 68 | + (e.g., ``fit``, ``validate``, ``test``, ``predict``). |
| 69 | + :return: The result of the parent class ``setup`` method. |
| 70 | + :rtype: Any |
| 71 | + """ |
| 72 | + # Override the compilation, compiling only for torch < 2.8, see |
| 73 | + # related issue at https://github.com/mathLab/PINA/issues/621 |
| 74 | + if torch.__version__ < "2.8": |
| 75 | + self.trainer.compile = True |
| 76 | + else: |
| 77 | + self.trainer.compile = False |
| 78 | + warnings.warn( |
| 79 | + "Compilation is disabled for torch >= 2.8. " |
| 80 | + "Forcing compilation may cause runtime errors or instability.", |
| 81 | + UserWarning, |
| 82 | + ) |
| 83 | + return super().setup(stage) |
| 84 | + |
49 | 85 | def optimization_cycle(self, batch, loss_residuals=None): |
50 | 86 | """ |
51 | 87 | The optimization cycle for the PINN solver. |
|
0 commit comments