Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pina/solver/physics_informed_solver/pinn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def setup(self, stage):
"""
# Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ < "2.8":
self.trainer.compile = True
else:
if torch.__version__ >= "2.8":
self.trainer.compile = False
warnings.warn(
"Compilation is disabled for torch >= 2.8. "
Expand Down
6 changes: 1 addition & 5 deletions pina/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,7 @@ def setup(self, stage):
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
if stage == "fit" and self.trainer.compile:
self._setup_compile()
if stage == "test" and (
self.trainer.compile and not self._is_compiled()
):
if self.trainer.compile and not self._is_compiled():
self._setup_compile()
return super().setup(stage)

Expand Down
41 changes: 38 additions & 3 deletions pina/trainer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Module for the Trainer."""

import sys
import warnings
import torch
import lightning
from .utils import check_consistency
from .utils import check_consistency, custom_warning_format
from .data import PinaDataModule
from .solver import SolverInterface, PINNInterface

# set the warning for compile options
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)


class Trainer(lightning.pytorch.Trainer):
"""
Expand Down Expand Up @@ -49,7 +54,8 @@ def __init__(
:param float val_size: The percentage of elements to include in the
validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
Expand Down Expand Up @@ -104,8 +110,17 @@ def __init__(
super().__init__(**kwargs)

# checking compilation and automatic batching
if compile is None or sys.platform == "win32":
# compilation disabled for Windows and for Python 3.14+
if (
compile is None
or sys.platform == "win32"
or sys.version_info >= (3, 14)
):
compile = False
warnings.warn(
"Compilation is disabled for Python 3.14+ and for Windows.",
UserWarning,
)

repeat = repeat if repeat is not None else False

Expand Down Expand Up @@ -325,3 +340,23 @@ def _check_consistency_and_set_defaults(
if batch_size is not None:
check_consistency(batch_size, int)
return pin_memory, num_workers, shuffle, batch_size

@property
def compile(self):
"""
Whether compilation is required or not.

:return: ``True`` if compilation is required, ``False`` otherwise.
:rtype: bool
"""
return self._compile

@compile.setter
def compile(self, value):
"""
Setting the value of compile.

:param bool value: Whether compilation is required or not.
"""
check_consistency(value, bool)
self._compile = value