diff --git a/pina/problem/zoo/__init__.py b/pina/problem/zoo/__init__.py
index c18d649d7..6e3d58e52 100644
--- a/pina/problem/zoo/__init__.py
+++ b/pina/problem/zoo/__init__.py
@@ -1,15 +1,19 @@
 """TODO"""
 
 __all__ = [
-    "Poisson2DSquareProblem",
     "SupervisedProblem",
-    "InversePoisson2DSquareProblem",
+    "HelmholtzProblem",
+    "AllenCahnProblem",
+    "AdvectionProblem",
+    "Poisson2DSquareProblem",
     "DiffusionReactionProblem",
-    "InverseDiffusionReactionProblem",
+    "InversePoisson2DSquareProblem",
 ]
 
-from .poisson_2d_square import Poisson2DSquareProblem
 from .supervised_problem import SupervisedProblem
-from .inverse_poisson_2d_square import InversePoisson2DSquareProblem
+from .helmholtz import HelmholtzProblem
+from .allen_cahn import AllenCahnProblem
+from .advection import AdvectionProblem
+from .poisson_2d_square import Poisson2DSquareProblem
 from .diffusion_reaction import DiffusionReactionProblem
-from .inverse_diffusion_reaction import InverseDiffusionReactionProblem
+from .inverse_poisson_2d_square import InversePoisson2DSquareProblem
diff --git a/pina/problem/zoo/advection.py b/pina/problem/zoo/advection.py
new file mode 100644
index 000000000..32c6afe78
--- /dev/null
+++ b/pina/problem/zoo/advection.py
@@ -0,0 +1,107 @@
+"""Formulation of the advection problem."""
+
+import torch
+from ... import Condition
+from ...operator import grad
+from ...equation import Equation
+from ...domain import CartesianDomain
+from ...utils import check_consistency
+from ...problem import SpatialProblem, TimeDependentProblem
+
+
+class AdvectionEquation(Equation):
+    """
+    Implementation of the advection equation.
+    """
+
+    def __init__(self, c):
+        """
+        Initialize the advection equation.
+
+        :param c: The advection velocity parameter.
+        :type c: float | int
+        """
+        self.c = c
+        check_consistency(self.c, (float, int))
+
+        def equation(input_, output_):
+            """
+            Implementation of the advection equation.
+
+            :param LabelTensor input_: Input data of the problem.
+            :param LabelTensor output_: Output data of the problem.
+            :return: The residual of the advection equation.
+            :rtype: LabelTensor
+            """
+            u_x = grad(output_, input_, components=["u"], d=["x"])
+            u_t = grad(output_, input_, components=["u"], d=["t"])
+            return u_t + self.c * u_x
+
+        super().__init__(equation)
+
+
+def initial_condition(input_, output_):
+    """
+    Implementation of the initial condition.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the initial condition.
+    :rtype: LabelTensor
+    """
+    return output_ - torch.sin(input_.extract("x"))
+
+
+class AdvectionProblem(SpatialProblem, TimeDependentProblem):
+    r"""
+    Implementation of the advection problem in the spatial interval
+    :math:`[0, 2 \pi]` and temporal interval :math:`[0, 1]`.
+
+    .. seealso::
+
+        **Original reference**: Wang, Sifan, et al. *An expert's guide to
+        training physics-informed neural networks*.
+        arXiv preprint arXiv:2308.08468 (2023).
+        DOI: `arXiv:2308.08468 <https://arxiv.org/abs/2308.08468>`_.
+    """
+
+    output_variables = ["u"]
+    spatial_domain = CartesianDomain({"x": [0, 2 * torch.pi]})
+    temporal_domain = CartesianDomain({"t": [0, 1]})
+
+    domains = {
+        "D": CartesianDomain({"x": [0, 2 * torch.pi], "t": [0, 1]}),
+        "t0": CartesianDomain({"x": [0, 2 * torch.pi], "t": 0.0}),
+    }
+
+    conditions = {
+        "t0": Condition(domain="t0", equation=Equation(initial_condition)),
+    }
+
+    def __init__(self, c=1.0):
+        """
+        Initialize the advection problem.
+
+        :param c: The advection velocity parameter.
+        :type c: float | int
+        """
+        super().__init__()
+
+        self.c = c
+        check_consistency(self.c, (float, int))
+
+        self.conditions["D"] = Condition(
+            domain="D", equation=AdvectionEquation(self.c)
+        )
+
+    def solution(self, pts):
+        """
+        Implementation of the analytical solution of the advection problem.
+
+        :param LabelTensor pts: Points where the solution is evaluated.
+        :return: The analytical solution of the advection problem.
+        :rtype: LabelTensor
+        """
+        sol = torch.sin(pts.extract("x") - self.c * pts.extract("t"))
+        sol.labels = self.output_variables
+        return sol
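Not part of the patch, for reviewers only: a minimal usage sketch of the new AdvectionProblem, assuming PINA is installed and that LabelTensor is importable from the package root as in the other zoo modules touched by this diff.

import torch
from pina import LabelTensor
from pina.problem.zoo import AdvectionProblem

# build the problem and sample its space-time domains, as in the new tests
problem = AdvectionProblem(c=1.0)
problem.discretise_domain(n=100, mode="random", domains="all")

# the analytical solution u(x, t) = sin(x - c * t) can be evaluated anywhere
pts = LabelTensor(torch.rand(5, 2), ["x", "t"])
print(problem.solution(pts))
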
diff --git a/pina/problem/zoo/allen_cahn.py b/pina/problem/zoo/allen_cahn.py
new file mode 100644
index 000000000..e4a9c4c41
--- /dev/null
+++ b/pina/problem/zoo/allen_cahn.py
@@ -0,0 +1,66 @@
+"""Formulation of the Allen-Cahn problem."""
+
+import torch
+from ... import Condition
+from ...equation import Equation
+from ...domain import CartesianDomain
+from ...operator import grad, laplacian
+from ...problem import SpatialProblem, TimeDependentProblem
+
+
+def allen_cahn_equation(input_, output_):
+    """
+    Implementation of the Allen-Cahn equation.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the Allen-Cahn equation.
+    :rtype: LabelTensor
+    """
+    u_t = grad(output_, input_, components=["u"], d=["t"])
+    u_xx = laplacian(output_, input_, components=["u"], d=["x"])
+    return u_t - 0.0001 * u_xx + 5 * output_**3 - 5 * output_
+
+
+def initial_condition(input_, output_):
+    """
+    Definition of the initial condition of the Allen-Cahn problem.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the initial condition.
+    :rtype: LabelTensor
+    """
+    x = input_.extract("x")
+    u_0 = x**2 * torch.cos(torch.pi * x)
+    return output_ - u_0
+
+
+class AllenCahnProblem(TimeDependentProblem, SpatialProblem):
+    r"""
+    Implementation of the Allen-Cahn problem in the spatial interval
+    :math:`[-1, 1]` and temporal interval :math:`[0, 1]`.
+
+    .. seealso::
+        **Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano,
+        Nikolaos Stergiopulos, and George E. Karniadakis.
+        *Residual-based attention and connection to information
+        bottleneck theory in PINNs*.
+        Computer Methods in Applied Mechanics and Engineering 421 (2024): 116805
+        DOI: `10.1016/j.cma.2024.116805
+        <https://doi.org/10.1016/j.cma.2024.116805>`_.
+    """
+
+    output_variables = ["u"]
+    spatial_domain = CartesianDomain({"x": [-1, 1]})
+    temporal_domain = CartesianDomain({"t": [0, 1]})
+
+    domains = {
+        "D": CartesianDomain({"x": [-1, 1], "t": [0, 1]}),
+        "t0": CartesianDomain({"x": [-1, 1], "t": 0.0}),
+    }
+
+    conditions = {
+        "D": Condition(domain="D", equation=Equation(allen_cahn_equation)),
+        "t0": Condition(domain="t0", equation=Equation(initial_condition)),
+    }
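For context (not in the diff): the module-level initial_condition above can be exercised directly; the sketch below, assuming LabelTensor arithmetic behaves as in the other zoo modules, should give a residual close to zero.

import torch
from pina import LabelTensor
from pina.problem.zoo.allen_cahn import initial_condition

# random points in [-1, 1] x [0, 1] and the exact initial state x^2 * cos(pi x)
pts = LabelTensor(torch.rand(5, 2), ["x", "t"])
u0 = LabelTensor(
    pts.extract("x") ** 2 * torch.cos(torch.pi * pts.extract("x")), ["u"]
)
print(initial_condition(pts, u0))  # expected to be ~0
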
diff --git a/pina/problem/zoo/diffusion_reaction.py b/pina/problem/zoo/diffusion_reaction.py
index e7bc6c2be..6d6485dda 100644
--- a/pina/problem/zoo/diffusion_reaction.py
+++ b/pina/problem/zoo/diffusion_reaction.py
@@ -1,22 +1,26 @@
-"""Definition of the diffusion-reaction problem."""
+"""Formulation of the diffusion-reaction problem."""
 
 import torch
-from pina import Condition
-from pina.problem import SpatialProblem, TimeDependentProblem
-from pina.equation.equation import Equation
-from pina.domain import CartesianDomain
-from pina.operator import grad
+from ... import Condition
+from ...domain import CartesianDomain
+from ...operator import grad, laplacian
+from ...equation import Equation, FixedValue
+from ...problem import SpatialProblem, TimeDependentProblem
 
 
 def diffusion_reaction(input_, output_):
     """
     Implementation of the diffusion-reaction equation.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the diffusion-reaction equation.
+    :rtype: LabelTensor
     """
     x = input_.extract("x")
     t = input_.extract("t")
-    u_t = grad(output_, input_, d="t")
-    u_x = grad(output_, input_, d="x")
-    u_xx = grad(u_x, input_, d="x")
+    u_t = grad(output_, input_, components=["u"], d=["t"])
+    u_xx = laplacian(output_, input_, components=["u"], d=["x"])
     r = torch.exp(-t) * (
         1.5 * torch.sin(2 * x)
         + (8 / 3) * torch.sin(3 * x)
@@ -26,30 +30,72 @@ def diffusion_reaction(input_, output_):
     return u_t - u_xx - r
 
 
-class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
+def initial_condition(input_, output_):
+    """
+    Definition of the initial condition of the diffusion-reaction problem.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the initial condition.
+    :rtype: LabelTensor
     """
-    Implementation of the diffusion-reaction problem on the spatial interval
-    [-pi, pi] and temporal interval [0,1].
+    x = input_.extract("x")
+    u_0 = (
+        torch.sin(x)
+        + (1 / 2) * torch.sin(2 * x)
+        + (1 / 3) * torch.sin(3 * x)
+        + (1 / 4) * torch.sin(4 * x)
+        + (1 / 8) * torch.sin(8 * x)
+    )
+    return output_ - u_0
+
+
+class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
+    r"""
+    Implementation of the diffusion-reaction problem in the spatial interval
+    :math:`[-\pi, \pi]` and temporal interval :math:`[0, 1]`.
+
+    .. seealso::
+        **Original reference**: Si, Chenhao, et al. *Complex Physics-Informed
+        Neural Network.* arXiv preprint arXiv:2502.04917 (2025).
+        DOI: `arXiv:2502.04917 <https://arxiv.org/abs/2502.04917>`_.
     """
 
     output_variables = ["u"]
     spatial_domain = CartesianDomain({"x": [-torch.pi, torch.pi]})
     temporal_domain = CartesianDomain({"t": [0, 1]})
 
+    domains = {
+        "D": CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
+        "g1": CartesianDomain({"x": -torch.pi, "t": [0, 1]}),
+        "g2": CartesianDomain({"x": torch.pi, "t": [0, 1]}),
+        "t0": CartesianDomain({"x": [-torch.pi, torch.pi], "t": 0.0}),
+    }
+
     conditions = {
-        "D": Condition(
-            domain=CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
-            equation=Equation(diffusion_reaction),
-        )
+        "D": Condition(domain="D", equation=Equation(diffusion_reaction)),
+        "g1": Condition(domain="g1", equation=FixedValue(0.0)),
+        "g2": Condition(domain="g2", equation=FixedValue(0.0)),
+        "t0": Condition(domain="t0", equation=Equation(initial_condition)),
     }
 
-    def _solution(self, pts):
+    def solution(self, pts):
+        """
+        Implementation of the analytical solution of the diffusion-reaction
+        problem.
+
+        :param LabelTensor pts: Points where the solution is evaluated.
+        :return: The analytical solution of the diffusion-reaction problem.
+        :rtype: LabelTensor
+        """
         t = pts.extract("t")
         x = pts.extract("x")
-        return torch.exp(-t) * (
+        sol = torch.exp(-t) * (
             torch.sin(x)
             + (1 / 2) * torch.sin(2 * x)
             + (1 / 3) * torch.sin(3 * x)
             + (1 / 4) * torch.sin(4 * x)
             + (1 / 8) * torch.sin(8 * x)
         )
+        sol.labels = self.output_variables
+        return sol
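Also not in the patch: a quick consistency check between the new initial_condition and the analytical solution at t = 0, where the exponential factor is one and the series reduces to the initial state.

import torch
from pina import LabelTensor
from pina.problem.zoo import DiffusionReactionProblem
from pina.problem.zoo.diffusion_reaction import initial_condition

problem = DiffusionReactionProblem()

# points on the t = 0 slice of the space-time domain
x = torch.rand(5, 1) * 2 * torch.pi - torch.pi
pts = LabelTensor(torch.cat([x, torch.zeros(5, 1)], dim=1), ["x", "t"])

u0 = problem.solution(pts)
print(initial_condition(pts, u0))  # expected to be ~0 at t = 0
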
diff --git a/pina/problem/zoo/helmholtz.py b/pina/problem/zoo/helmholtz.py
new file mode 100644
index 000000000..8564d8200
--- /dev/null
+++ b/pina/problem/zoo/helmholtz.py
@@ -0,0 +1,104 @@
+"""Formulation of the Helmholtz problem."""
+
+import torch
+from ... import Condition
+from ...operator import laplacian
+from ...domain import CartesianDomain
+from ...problem import SpatialProblem
+from ...utils import check_consistency
+from ...equation import Equation, FixedValue
+
+
+class HelmholtzEquation(Equation):
+    """
+    Implementation of the Helmholtz equation.
+    """
+
+    def __init__(self, alpha):
+        """
+        Initialize the Helmholtz equation.
+
+        :param alpha: Parameter of the forcing term.
+        :type alpha: float | int
+        """
+        self.alpha = alpha
+        check_consistency(alpha, (int, float))
+
+        def equation(input_, output_):
+            """
+            Implementation of the Helmholtz equation.
+
+            :param LabelTensor input_: Input data of the problem.
+            :param LabelTensor output_: Output data of the problem.
+            :return: The residual of the Helmholtz equation.
+            :rtype: LabelTensor
+            """
+            lap = laplacian(output_, input_, components=["u"], d=["x", "y"])
+            q = (
+                (1 - 2 * (self.alpha * torch.pi) ** 2)
+                * torch.sin(self.alpha * torch.pi * input_.extract("x"))
+                * torch.sin(self.alpha * torch.pi * input_.extract("y"))
+            )
+            return lap + output_ - q
+
+        super().__init__(equation)
+
+
+class HelmholtzProblem(SpatialProblem):
+    r"""
+    Implementation of the Helmholtz problem in the square domain
+    :math:`[-1, 1] \times [-1, 1]`.
+
+    .. seealso::
+        **Original reference**: Si, Chenhao, et al. *Complex Physics-Informed
+        Neural Network.* arXiv preprint arXiv:2502.04917 (2025).
+        DOI: `arXiv:2502.04917 <https://arxiv.org/abs/2502.04917>`_.
+    """
+
+    output_variables = ["u"]
+    spatial_domain = CartesianDomain({"x": [-1, 1], "y": [-1, 1]})
+
+    domains = {
+        "D": CartesianDomain({"x": [-1, 1], "y": [-1, 1]}),
+        "g1": CartesianDomain({"x": [-1, 1], "y": 1.0}),
+        "g2": CartesianDomain({"x": [-1, 1], "y": -1.0}),
+        "g3": CartesianDomain({"x": 1.0, "y": [-1, 1]}),
+        "g4": CartesianDomain({"x": -1.0, "y": [-1, 1]}),
+    }
+
+    conditions = {
+        "g1": Condition(domain="g1", equation=FixedValue(0.0)),
+        "g2": Condition(domain="g2", equation=FixedValue(0.0)),
+        "g3": Condition(domain="g3", equation=FixedValue(0.0)),
+        "g4": Condition(domain="g4", equation=FixedValue(0.0)),
+    }
+
+    def __init__(self, alpha=3.0):
+        """
+        Initialize the Helmholtz problem.
+
+        :param alpha: Parameter of the forcing term.
+        :type alpha: float | int
+        """
+        super().__init__()
+
+        self.alpha = alpha
+        check_consistency(alpha, (int, float))
+
+        self.conditions["D"] = Condition(
+            domain="D", equation=HelmholtzEquation(self.alpha)
+        )
+
+    def solution(self, pts):
+        """
+        Implementation of the analytical solution of the Helmholtz problem.
+
+        :param LabelTensor pts: Points where the solution is evaluated.
+        :return: The analytical solution of the Helmholtz problem.
+        :rtype: LabelTensor
+        """
+        sol = torch.sin(self.alpha * torch.pi * pts.extract("x")) * torch.sin(
+            self.alpha * torch.pi * pts.extract("y")
+        )
+        sol.labels = self.output_variables
+        return sol
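A note on the forcing term (not in the diff): for u = sin(alpha*pi*x) * sin(alpha*pi*y) the Laplacian is -2*(alpha*pi)^2 * u, so lap + u equals (1 - 2*(alpha*pi)^2) * u, which is exactly the q assembled in HelmholtzEquation; the residual therefore vanishes on the analytical solution returned by solution(). A minimal smoke-use sketch:

from pina.problem.zoo import HelmholtzProblem

# alpha sets the frequency of both the forcing term and the exact solution
problem = HelmholtzProblem(alpha=3.0)
problem.discretise_domain(n=100, mode="random", domains="all")
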
diff --git a/pina/problem/zoo/inverse_diffusion_reaction.py b/pina/problem/zoo/inverse_diffusion_reaction.py
deleted file mode 100644
index 0a0560557..000000000
--- a/pina/problem/zoo/inverse_diffusion_reaction.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""Definition of the diffusion-reaction problem."""
-
-import torch
-from pina import Condition, LabelTensor
-from pina.problem import SpatialProblem, TimeDependentProblem, InverseProblem
-from pina.equation.equation import Equation
-from pina.domain import CartesianDomain
-from pina.operator import grad
-
-
-def diffusion_reaction(input_, output_):
-    """
-    Implementation of the diffusion-reaction equation.
-    """
-    x = input_.extract("x")
-    t = input_.extract("t")
-    u_t = grad(output_, input_, d="t")
-    u_x = grad(output_, input_, d="x")
-    u_xx = grad(u_x, input_, d="x")
-    r = torch.exp(-t) * (
-        1.5 * torch.sin(2 * x)
-        + (8 / 3) * torch.sin(3 * x)
-        + (15 / 4) * torch.sin(4 * x)
-        + (63 / 8) * torch.sin(8 * x)
-    )
-    return u_t - u_xx - r
-
-
-class InverseDiffusionReactionProblem(
-    TimeDependentProblem, SpatialProblem, InverseProblem
-):
-    """
-    Implementation of the diffusion-reaction inverse problem on the spatial
-    interval [-pi, pi] and temporal interval [0,1], with unknown parameters
-    in the interval [-1,1].
-    """
-
-    output_variables = ["u"]
-    spatial_domain = CartesianDomain({"x": [-torch.pi, torch.pi]})
-    temporal_domain = CartesianDomain({"t": [0, 1]})
-    unknown_parameter_domain = CartesianDomain({"mu": [-1, 1]})
-
-    conditions = {
-        "D": Condition(
-            domain=CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
-            equation=Equation(diffusion_reaction),
-        ),
-        "data": Condition(
-            input=LabelTensor(torch.randn(10, 2), ["x", "t"]),
-            target=LabelTensor(torch.randn(10, 1), ["u"]),
-        ),
-    }
-
-    def _solution(self, pts):
-        t = pts.extract("t")
-        x = pts.extract("x")
-        return torch.exp(-t) * (
-            torch.sin(x)
-            + (1 / 2) * torch.sin(2 * x)
-            + (1 / 3) * torch.sin(3 * x)
-            + (1 / 4) * torch.sin(4 * x)
-            + (1 / 8) * torch.sin(8 * x)
-        )
diff --git a/pina/problem/zoo/inverse_poisson_2d_square.py b/pina/problem/zoo/inverse_poisson_2d_square.py
index 2d9bbe5ac..16b4ec1d9 100644
--- a/pina/problem/zoo/inverse_poisson_2d_square.py
+++ b/pina/problem/zoo/inverse_poisson_2d_square.py
@@ -1,17 +1,23 @@
-"""Definition of the inverse Poisson problem on a square domain."""
+"""Formulation of the inverse Poisson problem in a square domain."""
 
+import os
 import torch
-from pina import Condition, LabelTensor
-from pina.problem import SpatialProblem, InverseProblem
-from pina.operator import laplacian
-from pina.domain import CartesianDomain
-from pina.equation.equation import Equation
-from pina.equation.equation_factory import FixedValue
+from ... import Condition
+from ...operator import laplacian
+from ...domain import CartesianDomain
+from ...equation import Equation, FixedValue
+from ...problem import SpatialProblem, InverseProblem
 
 
 def laplace_equation(input_, output_, params_):
     """
     Implementation of the laplace equation.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :param dict params_: Parameters of the problem.
+    :return: The residual of the laplace equation.
+    :rtype: LabelTensor
     """
     force_term = torch.exp(
         -2 * (input_.extract(["x"]) - params_["mu1"]) ** 2
@@ -21,17 +27,34 @@ def laplace_equation(input_, output_, params_):
     return delta_u - force_term
 
 
+# Absolute path to the data directory
+data_dir = os.path.abspath(
+    os.path.join(
+        os.path.dirname(__file__), "../../../tutorials/tutorial7/data/"
+    )
+)
+
+# Load input data
+input_data = torch.load(
+    f=os.path.join(data_dir, "pts_0.5_0.5"), weights_only=False
+).extract(["x", "y"])
+
+# Load output data
+output_data = torch.load(
+    f=os.path.join(data_dir, "pinn_solution_0.5_0.5"), weights_only=False
+)
+
+
 class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
-    """
-    Implementation of the inverse 2-dimensional Poisson problem
-    on a square domain, with parameter domain [-1, 1] x [-1, 1].
+    r"""
+    Implementation of the inverse 2-dimensional Poisson problem in the square
+    domain :math:`[-2, 2] \times [-2, 2]`,
+    with unknown parameter domain :math:`[-1, 1] \times [-1, 1]`.
     """
 
     output_variables = ["u"]
     x_min, x_max = -2, 2
     y_min, y_max = -2, 2
-    data_input = LabelTensor(torch.rand(10, 2), ["x", "y"])
-    data_output = LabelTensor(torch.rand(10, 1), ["u"])
     spatial_domain = CartesianDomain({"x": [x_min, x_max], "y": [y_min, y_max]})
     unknown_parameter_domain = CartesianDomain({"mu1": [-1, 1], "mu2": [-1, 1]})
 
@@ -44,13 +67,10 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
     }
 
     conditions = {
-        "nil_g1": Condition(domain="g1", equation=FixedValue(0.0)),
-        "nil_g2": Condition(domain="g2", equation=FixedValue(0.0)),
-        "nil_g3": Condition(domain="g3", equation=FixedValue(0.0)),
-        "nil_g4": Condition(domain="g4", equation=FixedValue(0.0)),
-        "laplace_D": Condition(domain="D", equation=Equation(laplace_equation)),
-        "data": Condition(
-            input=data_input.extract(["x", "y"]),
-            target=data_output,
-        ),
+        "g1": Condition(domain="g1", equation=FixedValue(0.0)),
+        "g2": Condition(domain="g2", equation=FixedValue(0.0)),
+        "g3": Condition(domain="g3", equation=FixedValue(0.0)),
+        "g4": Condition(domain="g4", equation=FixedValue(0.0)),
+        "D": Condition(domain="D", equation=Equation(laplace_equation)),
+        "data": Condition(input=input_data, target=output_data),
    }
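The module above now loads the tutorial snapshots at import time from tutorials/tutorial7/data. If that trade-off is revisited, a lazy variant could look like the hypothetical helper below; it is illustrative only and not part of this patch.

import os
import torch


def _load_inverse_poisson_data(data_dir):
    # hypothetical helper: raise a clear error when the tutorial files are
    # missing instead of failing at import time
    pts_file = os.path.join(data_dir, "pts_0.5_0.5")
    sol_file = os.path.join(data_dir, "pinn_solution_0.5_0.5")
    if not (os.path.isfile(pts_file) and os.path.isfile(sol_file)):
        raise FileNotFoundError(f"tutorial data not found in {data_dir}")
    input_ = torch.load(f=pts_file, weights_only=False).extract(["x", "y"])
    target = torch.load(f=sol_file, weights_only=False)
    return input_, target
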
diff --git a/pina/problem/zoo/poisson_2d_square.py b/pina/problem/zoo/poisson_2d_square.py
index e65beb5bd..fef0b2e61 100644
--- a/pina/problem/zoo/poisson_2d_square.py
+++ b/pina/problem/zoo/poisson_2d_square.py
@@ -1,31 +1,35 @@
-"""Definition of the Poisson problem on a square domain."""
+"""Formulation of the Poisson problem in a square domain."""
 
 import torch
-from ..spatial_problem import SpatialProblem
-from ...operator import laplacian
 from ... import Condition
+from ...operator import laplacian
+from ...problem import SpatialProblem
 from ...domain import CartesianDomain
-from ...equation.equation import Equation
-from ...equation.equation_factory import FixedValue
+from ...equation import Equation, FixedValue
 
 
 def laplace_equation(input_, output_):
     """
     Implementation of the laplace equation.
+
+    :param LabelTensor input_: Input data of the problem.
+    :param LabelTensor output_: Output data of the problem.
+    :return: The residual of the laplace equation.
+    :rtype: LabelTensor
     """
-    force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin(
-        input_.extract(["y"]) * torch.pi
+    force_term = (
+        torch.sin(input_.extract(["x"]) * torch.pi)
+        * torch.sin(input_.extract(["y"]) * torch.pi)
+        * (2 * torch.pi**2)
     )
-    delta_u = laplacian(output_.extract(["u"]), input_)
+    delta_u = laplacian(output_, input_, components=["u"], d=["x", "y"])
     return delta_u - force_term
 
 
-my_laplace = Equation(laplace_equation)
-
-
 class Poisson2DSquareProblem(SpatialProblem):
-    """
-    Implementation of the 2-dimensional Poisson problem on a square domain.
+    r"""
+    Implementation of the 2-dimensional Poisson problem in the square domain
+    :math:`[0, 1] \times [0, 1]`.
     """
 
    output_variables = ["u"]
@@ -33,24 +37,31 @@ class Poisson2DSquareProblem(SpatialProblem):
 
     domains = {
         "D": CartesianDomain({"x": [0, 1], "y": [0, 1]}),
-        "g1": CartesianDomain({"x": [0, 1], "y": 1}),
-        "g2": CartesianDomain({"x": [0, 1], "y": 0}),
-        "g3": CartesianDomain({"x": 1, "y": [0, 1]}),
-        "g4": CartesianDomain({"x": 0, "y": [0, 1]}),
+        "g1": CartesianDomain({"x": [0, 1], "y": 1.0}),
+        "g2": CartesianDomain({"x": [0, 1], "y": 0.0}),
+        "g3": CartesianDomain({"x": 1.0, "y": [0, 1]}),
+        "g4": CartesianDomain({"x": 0.0, "y": [0, 1]}),
     }
 
     conditions = {
-        "nil_g1": Condition(domain="g1", equation=FixedValue(0.0)),
-        "nil_g2": Condition(domain="g2", equation=FixedValue(0.0)),
-        "nil_g3": Condition(domain="g3", equation=FixedValue(0.0)),
-        "nil_g4": Condition(domain="g4", equation=FixedValue(0.0)),
-        "laplace_D": Condition(domain="D", equation=my_laplace),
+        "g1": Condition(domain="g1", equation=FixedValue(0.0)),
+        "g2": Condition(domain="g2", equation=FixedValue(0.0)),
+        "g3": Condition(domain="g3", equation=FixedValue(0.0)),
+        "g4": Condition(domain="g4", equation=FixedValue(0.0)),
+        "D": Condition(domain="D", equation=Equation(laplace_equation)),
     }
 
-    def poisson_sol(self, pts):
-        """TODO"""
+    def solution(self, pts):
+        """
+        Implementation of the analytical solution of the Poisson problem.
 
-        return -(
+        :param LabelTensor pts: Points where the solution is evaluated.
+        :return: The analytical solution of the Poisson problem.
+        :rtype: LabelTensor
+        """
+        sol = -(
             torch.sin(pts.extract(["x"]) * torch.pi)
             * torch.sin(pts.extract(["y"]) * torch.pi)
         )
+        sol.labels = self.output_variables
+        return sol
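End-to-end sketch (not in the diff) of how the renamed conditions and the new solution method are consumed by a solver; the PINN, Trainer and FeedForward arguments are assumptions based on how the solver tests further below use these classes.

from pina.model import FeedForward
from pina.solver import PINN
from pina.trainer import Trainer
from pina.problem.zoo import Poisson2DSquareProblem

problem = Poisson2DSquareProblem()
problem.discretise_domain(50)

# small network mapping (x, y) -> u, as in the solver tests
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
solver = PINN(problem=problem, model=model)
Trainer(solver=solver, max_epochs=5, accelerator="cpu").train()
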
diff --git a/pina/problem/zoo/supervised_problem.py b/pina/problem/zoo/supervised_problem.py
index 1d4654945..7e39a502a 100644
--- a/pina/problem/zoo/supervised_problem.py
+++ b/pina/problem/zoo/supervised_problem.py
@@ -1,4 +1,4 @@
-"""TODO"""
+"""Formulation of a Supervised Problem in PINA."""
 
 from ..abstract_problem import AbstractProblem
 from ... import Condition
@@ -7,11 +7,11 @@
 
 class SupervisedProblem(AbstractProblem):
     """
-    A problem definition for supervised learning in PINA.
+    Definition of a supervised learning problem in PINA.
 
-    This class allows an easy and straightforward definition of a
-    Supervised problem, based on a single condition of type
-    `InputTargetCondition`
+    This class provides a simple way to define a supervised problem
+    using a single condition of type
+    :class:`~pina.condition.input_target_condition.InputTargetCondition`.
 
     :Example:
         >>> import torch
@@ -25,12 +25,12 @@ class SupervisedProblem(AbstractProblem):
 
     def __init__(self, input_, output_):
         """
-        Initialize the SupervisedProblem class
+        Initialize the SupervisedProblem class.
 
-        :param input_: Input data of the problem
-        :type input_: torch.Tensor | Graph
-        :param output_: Output data of the problem
-        :type output_: torch.Tensor
+        :param input_: Input data of the problem.
+        :type input_: torch.Tensor | Graph
+        :param output_: Output data of the problem.
+        :type output_: torch.Tensor | Graph
         """
         if isinstance(input_, Graph):
             input_ = input_.data
diff --git a/tests/test_problem_zoo/test_advection.py b/tests/test_problem_zoo/test_advection.py
new file mode 100644
index 000000000..4cfc27cd0
--- /dev/null
+++ b/tests/test_problem_zoo/test_advection.py
@@ -0,0 +1,18 @@
+import pytest
+from pina.problem.zoo import AdvectionProblem
+from pina.problem import SpatialProblem, TimeDependentProblem
+
+
+@pytest.mark.parametrize("c", [1.5, 3])
+def test_constructor(c):
+    print(f"Testing with c = {c} (type: {type(c)})")
+    problem = AdvectionProblem(c=c)
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, SpatialProblem)
+    assert isinstance(problem, TimeDependentProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
+
+    with pytest.raises(ValueError):
+        AdvectionProblem(c="a")
diff --git a/tests/test_problem_zoo/test_allen_cahn.py b/tests/test_problem_zoo/test_allen_cahn.py
new file mode 100644
index 000000000..851348077
--- /dev/null
+++ b/tests/test_problem_zoo/test_allen_cahn.py
@@ -0,0 +1,12 @@
+from pina.problem.zoo import AllenCahnProblem
+from pina.problem import SpatialProblem, TimeDependentProblem
+
+
+def test_constructor():
+    problem = AllenCahnProblem()
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, SpatialProblem)
+    assert isinstance(problem, TimeDependentProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
diff --git a/tests/test_problem_zoo/test_diffusion_reaction.py b/tests/test_problem_zoo/test_diffusion_reaction.py
new file mode 100644
index 000000000..51709b29c
--- /dev/null
+++ b/tests/test_problem_zoo/test_diffusion_reaction.py
@@ -0,0 +1,12 @@
+from pina.problem.zoo import DiffusionReactionProblem
+from pina.problem import TimeDependentProblem, SpatialProblem
+
+
+def test_constructor():
+    problem = DiffusionReactionProblem()
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, TimeDependentProblem)
+    assert isinstance(problem, SpatialProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
diff --git a/tests/test_problem_zoo/test_helmholtz.py b/tests/test_problem_zoo/test_helmholtz.py
new file mode 100644
index 000000000..ad8618a06
--- /dev/null
+++ b/tests/test_problem_zoo/test_helmholtz.py
@@ -0,0 +1,16 @@
+import pytest
+from pina.problem.zoo import HelmholtzProblem
+from pina.problem import SpatialProblem
+
+
+@pytest.mark.parametrize("alpha", [1.5, 3])
+def test_constructor(alpha):
+    problem = HelmholtzProblem(alpha=alpha)
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, SpatialProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
+
+    with pytest.raises(ValueError):
+        HelmholtzProblem(alpha="a")
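A possible follow-up, sketched here and not included in the patch: the constructor tests above cover discretisation and typing but never touch the new solution methods; a hypothetical extra test along these lines would exercise them as well.

import torch
from pina import LabelTensor
from pina.problem.zoo import Poisson2DSquareProblem


def test_solution_shape():
    # hypothetical extra test: one output value per evaluation point
    problem = Poisson2DSquareProblem()
    pts = LabelTensor(torch.rand(10, 2), ["x", "y"])
    sol = problem.solution(pts)
    assert isinstance(sol, LabelTensor)
    assert sol.shape == (10, 1)
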
diff --git a/tests/test_problem_zoo/test_inverse_poisson_2d_square.py b/tests/test_problem_zoo/test_inverse_poisson_2d_square.py
new file mode 100644
index 000000000..20a60e636
--- /dev/null
+++ b/tests/test_problem_zoo/test_inverse_poisson_2d_square.py
@@ -0,0 +1,12 @@
+from pina.problem.zoo import InversePoisson2DSquareProblem
+from pina.problem import InverseProblem, SpatialProblem
+
+
+def test_constructor():
+    problem = InversePoisson2DSquareProblem()
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, InverseProblem)
+    assert isinstance(problem, SpatialProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
diff --git a/tests/test_problem_zoo/test_poisson_2d_square.py b/tests/test_problem_zoo/test_poisson_2d_square.py
index 272eb8c5a..ed7be0425 100644
--- a/tests/test_problem_zoo/test_poisson_2d_square.py
+++ b/tests/test_problem_zoo/test_poisson_2d_square.py
@@ -1,5 +1,11 @@
 from pina.problem.zoo import Poisson2DSquareProblem
+from pina.problem import SpatialProblem
 
 
 def test_constructor():
-    Poisson2DSquareProblem()
+    problem = Poisson2DSquareProblem()
+    problem.discretise_domain(n=10, mode="random", domains="all")
+    assert problem.are_all_domains_discretised
+    assert isinstance(problem, SpatialProblem)
+    assert hasattr(problem, "conditions")
+    assert isinstance(problem.conditions, dict)
diff --git a/tests/test_solver/test_causal_pinn.py b/tests/test_solver/test_causal_pinn.py
index 107502f8a..4e72732d3 100644
--- a/tests/test_solver/test_causal_pinn.py
+++ b/tests/test_solver/test_causal_pinn.py
@@ -6,10 +6,7 @@
 from pina.solver import CausalPINN
 from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.problem.zoo import (
-    DiffusionReactionProblem,
-    InverseDiffusionReactionProblem,
-)
+from pina.problem.zoo import DiffusionReactionProblem
 from pina.condition import (
     InputTargetCondition,
     InputEquationCondition,
@@ -28,12 +25,9 @@ class DummySpatialProblem(SpatialProblem):
     spatial_domain = None
 
 
-# define problems and model
+# define problems
 problem = DiffusionReactionProblem()
 problem.discretise_domain(50)
-inverse_problem = InverseDiffusionReactionProblem()
-inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -42,8 +36,11 @@ class DummySpatialProblem(SpatialProblem):
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("problem", [problem])
 @pytest.mark.parametrize("eps", [100, 100.1])
 def test_constructor(problem, eps):
     with pytest.raises(ValueError):
@@ -57,7 +54,7 @@ def test_constructor(problem, eps):
     )
 
 
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("problem", [problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
 def test_solver_train(problem, batch_size, compile):
@@ -77,7 +74,7 @@ def test_solver_train(problem, batch_size, compile):
     assert isinstance(solver.model, OptimizedModule)
 
 
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("problem", [problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
 def test_solver_validation(problem, batch_size, compile):
@@ -97,7 +94,7 @@ def test_solver_validation(problem, batch_size, compile):
     assert isinstance(solver.model, OptimizedModule)
 
 
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("problem", [problem])
 @pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
 @pytest.mark.parametrize("compile", [True, False])
 def test_solver_test(problem, batch_size, compile):
@@ -117,7 +114,7 @@ def test_solver_test(problem, batch_size, compile):
     assert isinstance(solver.model, OptimizedModule)
 
 
-@pytest.mark.parametrize("problem", [problem, inverse_problem])
+@pytest.mark.parametrize("problem", [problem])
 def test_train_load_restore(problem):
     dir = "tests/test_solver/tmp"
     problem = problem
diff --git a/tests/test_solver/test_competitive_pinn.py b/tests/test_solver/test_competitive_pinn.py
index c5f8017a2..64fb28058 100644
--- a/tests/test_solver/test_competitive_pinn.py
+++ b/tests/test_solver/test_competitive_pinn.py
@@ -17,12 +17,16 @@
 from torch._dynamo.eval_frame import OptimizedModule
 
 
-# define problems and model
+# define problems
 problem = Poisson()
 problem.discretise_domain(50)
 inverse_problem = InversePoisson()
 inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# reduce the number of data points to speed up testing
+data_condition = inverse_problem.conditions["data"]
+data_condition.input = data_condition.input[:10]
+data_condition.target = data_condition.target[:10]
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -31,6 +35,9 @@
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("discr", [None, model])
diff --git a/tests/test_solver/test_gradient_pinn.py b/tests/test_solver/test_gradient_pinn.py
index c572036ea..31666db3d 100644
--- a/tests/test_solver/test_gradient_pinn.py
+++ b/tests/test_solver/test_gradient_pinn.py
@@ -28,12 +28,16 @@ class DummyTimeProblem(TimeDependentProblem):
     conditions = {}
 
 
-# define problems and model
+# define problems
 problem = Poisson()
 problem.discretise_domain(50)
 inverse_problem = InversePoisson()
 inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# reduce the number of data points to speed up testing
+data_condition = inverse_problem.conditions["data"]
+data_condition.input = data_condition.input[:10]
+data_condition.target = data_condition.target[:10]
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -42,6 +46,9 @@ class DummyTimeProblem(TimeDependentProblem):
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 def test_constructor(problem):
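The four-line data-trimming block above is repeated verbatim in the competitive, gradient, PINN, RBA and self-adaptive test modules (see the remaining hunks below). A shared helper, sketched here as a hypothetical addition to a tests/test_solver conftest, would remove the duplication; it only reuses calls already present in this patch.

from pina.problem.zoo import InversePoisson2DSquareProblem


def make_inverse_problem(n_data=10, n_collocation=50):
    """Hypothetical shared helper: inverse Poisson problem with trimmed data."""
    problem = InversePoisson2DSquareProblem()
    problem.discretise_domain(n_collocation)
    data_condition = problem.conditions["data"]
    data_condition.input = data_condition.input[:n_data]
    data_condition.target = data_condition.target[:n_data]
    return problem
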
diff --git a/tests/test_solver/test_pinn.py b/tests/test_solver/test_pinn.py
index 98d14389e..97511cb14 100644
--- a/tests/test_solver/test_pinn.py
+++ b/tests/test_solver/test_pinn.py
@@ -17,12 +17,16 @@
 from torch._dynamo.eval_frame import OptimizedModule
 
 
-# define problems and model
+# define problems
 problem = Poisson()
 problem.discretise_domain(50)
 inverse_problem = InversePoisson()
 inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# reduce the number of data points to speed up testing
+data_condition = inverse_problem.conditions["data"]
+data_condition.input = data_condition.input[:10]
+data_condition.target = data_condition.target[:10]
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -31,6 +35,9 @@
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 def test_constructor(problem):
diff --git a/tests/test_solver/test_rba_pinn.py b/tests/test_solver/test_rba_pinn.py
index ba74eba91..f355aab02 100644
--- a/tests/test_solver/test_rba_pinn.py
+++ b/tests/test_solver/test_rba_pinn.py
@@ -16,12 +16,16 @@
 )
 from torch._dynamo.eval_frame import OptimizedModule
 
-# define problems and model
+# define problems
 problem = Poisson()
 problem.discretise_domain(50)
 inverse_problem = InversePoisson()
 inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# reduce the number of data points to speed up testing
+data_condition = inverse_problem.conditions["data"]
+data_condition.input = data_condition.input[:10]
+data_condition.target = data_condition.target[:10]
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -30,6 +34,9 @@
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("eta", [1, 0.001])
diff --git a/tests/test_solver/test_self_adaptive_pinn.py b/tests/test_solver/test_self_adaptive_pinn.py
index b42472df5..48e3d9f8b 100644
--- a/tests/test_solver/test_self_adaptive_pinn.py
+++ b/tests/test_solver/test_self_adaptive_pinn.py
@@ -17,12 +17,16 @@
 from torch._dynamo.eval_frame import OptimizedModule
 
 
-# make the problem
+# define problems
 problem = Poisson()
 problem.discretise_domain(50)
 inverse_problem = InversePoisson()
 inverse_problem.discretise_domain(50)
-model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# reduce the number of data points to speed up testing
+data_condition = inverse_problem.conditions["data"]
+data_condition.input = data_condition.input[:10]
+data_condition.target = data_condition.target[:10]
 
 # add input-output condition to test supervised learning
 input_pts = torch.rand(50, len(problem.input_variables))
@@ -31,6 +35,9 @@
 output_pts = LabelTensor(output_pts, problem.output_variables)
 problem.conditions["data"] = Condition(input=input_pts, target=output_pts)
 
+# define model
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
 
 @pytest.mark.parametrize("problem", [problem, inverse_problem])
 @pytest.mark.parametrize("weight_fn", [torch.nn.Sigmoid(), torch.nn.Tanh()])