diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 9155e19ec..cd0e25d23 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -71,10 +71,8 @@ def setup(self, stage): """ # Override the compilation, compiling only for torch < 2.8, see # related issue at https://github.com/mathLab/PINA/issues/621 - if torch.__version__ < "2.8": - self.trainer.compile = True - else: - self.trainer.compile = False + if torch.__version__ >= "2.8": + self.trainer._compile = False warnings.warn( "Compilation is disabled for torch >= 2.8. " "Forcing compilation may cause runtime errors or instability.", diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 6948ec664..442574224 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -174,11 +174,7 @@ def setup(self, stage): :return: The result of the parent class ``setup`` method. :rtype: Any """ - if stage == "fit" and self.trainer.compile: - self._setup_compile() - if stage == "test" and ( - self.trainer.compile and not self._is_compiled() - ): + if self.trainer.compile and not self._is_compiled(): self._setup_compile() return super().setup(stage) diff --git a/pina/trainer.py b/pina/trainer.py index 78dd77adf..bb819905f 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -3,10 +3,15 @@ import sys import torch import lightning -from .utils import check_consistency +import warnings +from .utils import check_consistency, custom_warning_format from .data import PinaDataModule from .solver import SolverInterface, PINNInterface +# set the warning for compile options +warnings.formatwarning = custom_warning_format +warnings.filterwarnings("always", category=UserWarning) + class Trainer(lightning.pytorch.Trainer): """ @@ -49,7 +54,8 @@ def __init__( :param float val_size: The percentage of elements to include in the validation dataset. Default is ``0.0``. :param bool compile: If ``True``, the model is compiled before training. - Default is ``False``. For Windows users, it is always disabled. + Default is ``False``. For Windows users, it is always disabled. Not + supported for python version greater or equal than 3.14. :param bool repeat: Whether to repeat the dataset data in each condition during training. For further details, see the :class:`~pina.data.data_module.PinaDataModule` class. Default is @@ -104,8 +110,20 @@ def __init__( super().__init__(**kwargs) # checking compilation and automatic batching - if compile is None or sys.platform == "win32": - compile = False + # compile disambled for windows and py>=3.14 + compile = False + # if ( + # compile is None + # or sys.platform == "win32" + # or sys.version_info >= (3, 14) + # ): + # compile = False + # raise KeyError + # warnings.warn( + # "Compilation is disabled for Python 3.14+. " + # "Compilation is also disabled for Windows 3.2.", + # UserWarning, + # ) repeat = repeat if repeat is not None else False @@ -114,7 +132,7 @@ def __init__( ) # set attributes - self.compile = compile + self._compile = compile self.solver = solver self.batch_size = batch_size self._move_to_device() @@ -325,3 +343,7 @@ def _check_consistency_and_set_defaults( if batch_size is not None: check_consistency(batch_size, int) return pin_memory, num_workers, shuffle, batch_size + + @property + def compile(self): + return self._compile \ No newline at end of file