diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 9155e19ec..65a0dd78f 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -71,9 +71,7 @@ def setup(self, stage): """ # Override the compilation, compiling only for torch < 2.8, see # related issue at https://github.com/mathLab/PINA/issues/621 - if torch.__version__ < "2.8": - self.trainer.compile = True - else: + if torch.__version__ >= "2.8": self.trainer.compile = False warnings.warn( "Compilation is disabled for torch >= 2.8. " diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 6948ec664..442574224 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -174,11 +174,7 @@ def setup(self, stage): :return: The result of the parent class ``setup`` method. :rtype: Any """ - if stage == "fit" and self.trainer.compile: - self._setup_compile() - if stage == "test" and ( - self.trainer.compile and not self._is_compiled() - ): + if self.trainer.compile and not self._is_compiled(): self._setup_compile() return super().setup(stage) diff --git a/pina/trainer.py b/pina/trainer.py index 78dd77adf..8e1d95110 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,12 +1,17 @@ """Module for the Trainer.""" import sys +import warnings import torch import lightning -from .utils import check_consistency +from .utils import check_consistency, custom_warning_format from .data import PinaDataModule from .solver import SolverInterface, PINNInterface +# set the warning for compile options +warnings.formatwarning = custom_warning_format +warnings.filterwarnings("always", category=UserWarning) + class Trainer(lightning.pytorch.Trainer): """ @@ -49,7 +54,8 @@ def __init__( :param float val_size: The percentage of elements to include in the validation dataset. Default is ``0.0``. :param bool compile: If ``True``, the model is compiled before training. - Default is ``False``. For Windows users, it is always disabled. + Default is ``False``. For Windows users, it is always disabled. Not + supported for python version greater or equal than 3.14. :param bool repeat: Whether to repeat the dataset data in each condition during training. For further details, see the :class:`~pina.data.data_module.PinaDataModule` class. Default is @@ -104,8 +110,17 @@ def __init__( super().__init__(**kwargs) # checking compilation and automatic batching - if compile is None or sys.platform == "win32": + # compilation disabled for Windows and for Python 3.14+ + if ( + compile is None + or sys.platform == "win32" + or sys.version_info >= (3, 14) + ): compile = False + warnings.warn( + "Compilation is disabled for Python 3.14+ and for Windows.", + UserWarning, + ) repeat = repeat if repeat is not None else False @@ -325,3 +340,23 @@ def _check_consistency_and_set_defaults( if batch_size is not None: check_consistency(batch_size, int) return pin_memory, num_workers, shuffle, batch_size + + @property + def compile(self): + """ + Whether compilation is required or not. + + :return: ``True`` if compilation is required, ``False`` otherwise. + :rtype: bool + """ + return self._compile + + @compile.setter + def compile(self, value): + """ + Setting the value of compile. + + :param bool value: Whether compilation is required or not. + """ + check_consistency(value, bool) + self._compile = value