Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pina/solver/physics_informed_solver/pinn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,8 @@ def setup(self, stage):
"""
# Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ < "2.8":
self.trainer.compile = True
else:
self.trainer.compile = False
if torch.__version__ >= "2.8":
self.trainer._compile = False
warnings.warn(
"Compilation is disabled for torch >= 2.8. "
"Forcing compilation may cause runtime errors or instability.",
Expand Down
6 changes: 1 addition & 5 deletions pina/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,7 @@ def setup(self, stage):
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
if stage == "fit" and self.trainer.compile:
self._setup_compile()
if stage == "test" and (
self.trainer.compile and not self._is_compiled()
):
if self.trainer.compile and not self._is_compiled():
self._setup_compile()
return super().setup(stage)

Expand Down
32 changes: 27 additions & 5 deletions pina/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
import sys
import torch
import lightning
from .utils import check_consistency
import warnings
from .utils import check_consistency, custom_warning_format
from .data import PinaDataModule
from .solver import SolverInterface, PINNInterface

# set the warning for compile options
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)


class Trainer(lightning.pytorch.Trainer):
"""
Expand Down Expand Up @@ -49,7 +54,8 @@ def __init__(
:param float val_size: The percentage of elements to include in the
validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
Expand Down Expand Up @@ -104,8 +110,20 @@ def __init__(
super().__init__(**kwargs)

# checking compilation and automatic batching
if compile is None or sys.platform == "win32":
compile = False
# compile disambled for windows and py>=3.14
compile = False
# if (
# compile is None
# or sys.platform == "win32"
# or sys.version_info >= (3, 14)
# ):
# compile = False
# raise KeyError
# warnings.warn(
# "Compilation is disabled for Python 3.14+. "
# "Compilation is also disabled for Windows 3.2.",
# UserWarning,
# )

repeat = repeat if repeat is not None else False

Expand All @@ -114,7 +132,7 @@ def __init__(
)

# set attributes
self.compile = compile
self._compile = compile
self.solver = solver
self.batch_size = batch_size
self._move_to_device()
Expand Down Expand Up @@ -325,3 +343,7 @@ def _check_consistency_and_set_defaults(
if batch_size is not None:
check_consistency(batch_size, int)
return pin_memory, num_workers, shuffle, batch_size

@property
def compile(self):
return self._compile
Loading