Skip to content

Commit ef79146

Browse files
Adding new problems to problem.zoo (#484)
* adding problems * add tests * update doc + formatting --------- Co-authored-by: Dario Coscia <[email protected]>
1 parent 7886c38 commit ef79146

21 files changed

+569
-167
lines changed

pina/problem/zoo/__init__.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""TODO"""
22

33
__all__ = [
4-
"Poisson2DSquareProblem",
54
"SupervisedProblem",
6-
"InversePoisson2DSquareProblem",
5+
"HelmholtzProblem",
6+
"AllenCahnProblem",
7+
"AdvectionProblem",
8+
"Poisson2DSquareProblem",
79
"DiffusionReactionProblem",
8-
"InverseDiffusionReactionProblem",
10+
"InversePoisson2DSquareProblem",
911
]
1012

11-
from .poisson_2d_square import Poisson2DSquareProblem
1213
from .supervised_problem import SupervisedProblem
13-
from .inverse_poisson_2d_square import InversePoisson2DSquareProblem
14+
from .helmholtz import HelmholtzProblem
15+
from .allen_cahn import AllenCahnProblem
16+
from .advection import AdvectionProblem
17+
from .poisson_2d_square import Poisson2DSquareProblem
1418
from .diffusion_reaction import DiffusionReactionProblem
15-
from .inverse_diffusion_reaction import InverseDiffusionReactionProblem
19+
from .inverse_poisson_2d_square import InversePoisson2DSquareProblem

pina/problem/zoo/advection.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Formulation of the advection problem."""
2+
3+
import torch
4+
from ... import Condition
5+
from ...operator import grad
6+
from ...equation import Equation
7+
from ...domain import CartesianDomain
8+
from ...utils import check_consistency
9+
from ...problem import SpatialProblem, TimeDependentProblem
10+
11+
12+
class AdvectionEquation(Equation):
13+
"""
14+
Implementation of the advection equation.
15+
"""
16+
17+
def __init__(self, c):
18+
"""
19+
Initialize the advection equation.
20+
21+
:param c: The advection velocity parameter.
22+
:type c: float | int
23+
"""
24+
self.c = c
25+
check_consistency(self.c, (float, int))
26+
27+
def equation(input_, output_):
28+
"""
29+
Implementation of the advection equation.
30+
31+
:param LabelTensor input_: Input data of the problem.
32+
:param LabelTensor output_: Output data of the problem.
33+
:return: The residual of the advection equation.
34+
:rtype: LabelTensor
35+
"""
36+
u_x = grad(output_, input_, components=["u"], d=["x"])
37+
u_t = grad(output_, input_, components=["u"], d=["t"])
38+
return u_t + self.c * u_x
39+
40+
super().__init__(equation)
41+
42+
43+
def initial_condition(input_, output_):
44+
"""
45+
Implementation of the initial condition.
46+
47+
:param LabelTensor input_: Input data of the problem.
48+
:param LabelTensor output_: Output data of the problem.
49+
:return: The residual of the initial condition.
50+
:rtype: LabelTensor
51+
"""
52+
return output_ - torch.sin(input_.extract("x"))
53+
54+
55+
class AdvectionProblem(SpatialProblem, TimeDependentProblem):
56+
r"""
57+
Implementation of the advection problem in the spatial interval
58+
:math:`[0, 2 \pi]` and temporal interval :math:`[0, 1]`.
59+
60+
.. seealso::
61+
62+
**Original reference**: Wang, Sifan, et al. *An expert's guide to
63+
training physics-informed neural networks*.
64+
arXiv preprint arXiv:2308.08468 (2023).
65+
DOI: `arXiv:2308.08468 <https://arxiv.org/abs/2308.08468>`_.
66+
"""
67+
68+
output_variables = ["u"]
69+
spatial_domain = CartesianDomain({"x": [0, 2 * torch.pi]})
70+
temporal_domain = CartesianDomain({"t": [0, 1]})
71+
72+
domains = {
73+
"D": CartesianDomain({"x": [0, 2 * torch.pi], "t": [0, 1]}),
74+
"t0": CartesianDomain({"x": [0, 2 * torch.pi], "t": 0.0}),
75+
}
76+
77+
conditions = {
78+
"t0": Condition(domain="t0", equation=Equation(initial_condition)),
79+
}
80+
81+
def __init__(self, c=1.0):
82+
"""
83+
Initialize the advection problem.
84+
85+
:param c: The advection velocity parameter.
86+
:type c: float | int
87+
"""
88+
super().__init__()
89+
90+
self.c = c
91+
check_consistency(self.c, (float, int))
92+
93+
self.conditions["D"] = Condition(
94+
domain="D", equation=AdvectionEquation(self.c)
95+
)
96+
97+
def solution(self, pts):
98+
"""
99+
Implementation of the analytical solution of the advection problem.
100+
101+
:param LabelTensor pts: Points where the solution is evaluated.
102+
:return: The analytical solution of the advection problem.
103+
:rtype: LabelTensor
104+
"""
105+
sol = torch.sin(pts.extract("x") - self.c * pts.extract("t"))
106+
sol.labels = self.output_variables
107+
return sol

pina/problem/zoo/allen_cahn.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Formulation of the Allen Cahn problem."""
2+
3+
import torch
4+
from ... import Condition
5+
from ...equation import Equation
6+
from ...domain import CartesianDomain
7+
from ...operator import grad, laplacian
8+
from ...problem import SpatialProblem, TimeDependentProblem
9+
10+
11+
def allen_cahn_equation(input_, output_):
12+
"""
13+
Implementation of the Allen Cahn equation.
14+
15+
:param LabelTensor input_: Input data of the problem.
16+
:param LabelTensor output_: Output data of the problem.
17+
:return: The residual of the Allen Cahn equation.
18+
:rtype: LabelTensor
19+
"""
20+
u_t = grad(output_, input_, components=["u"], d=["t"])
21+
u_xx = laplacian(output_, input_, components=["u"], d=["x"])
22+
return u_t - 0.0001 * u_xx + 5 * output_**3 - 5 * output_
23+
24+
25+
def initial_condition(input_, output_):
26+
"""
27+
Definition of the initial condition of the Allen Cahn problem.
28+
29+
:param LabelTensor input_: Input data of the problem.
30+
:param LabelTensor output_: Output data of the problem.
31+
:return: The residual of the initial condition.
32+
:rtype: LabelTensor
33+
"""
34+
x = input_.extract("x")
35+
u_0 = x**2 * torch.cos(torch.pi * x)
36+
return output_ - u_0
37+
38+
39+
class AllenCahnProblem(TimeDependentProblem, SpatialProblem):
40+
r"""
41+
Implementation of the Allen Cahn problem in the spatial interval
42+
:math:`[-1, 1]` and temporal interval :math:`[0, 1]`.
43+
44+
.. seealso::
45+
**Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano,
46+
Nikolaos Stergiopulos, and George E. Karniadakis.
47+
*Residual-based attention and connection to information
48+
bottleneck theory in PINNs*.
49+
Computer Methods in Applied Mechanics and Engineering 421 (2024): 116805
50+
DOI: `10.1016/
51+
j.cma.2024.116805 <https://doi.org/10.1016/j.cma.2024.116805>`_.
52+
"""
53+
54+
output_variables = ["u"]
55+
spatial_domain = CartesianDomain({"x": [-1, 1]})
56+
temporal_domain = CartesianDomain({"t": [0, 1]})
57+
58+
domains = {
59+
"D": CartesianDomain({"x": [-1, 1], "t": [0, 1]}),
60+
"t0": CartesianDomain({"x": [-1, 1], "t": 0.0}),
61+
}
62+
63+
conditions = {
64+
"D": Condition(domain="D", equation=Equation(allen_cahn_equation)),
65+
"t0": Condition(domain="t0", equation=Equation(initial_condition)),
66+
}
Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
1-
"""Definition of the diffusion-reaction problem."""
1+
"""Formulation of the diffusion-reaction problem."""
22

33
import torch
4-
from pina import Condition
5-
from pina.problem import SpatialProblem, TimeDependentProblem
6-
from pina.equation.equation import Equation
7-
from pina.domain import CartesianDomain
8-
from pina.operator import grad
4+
from ... import Condition
5+
from ...domain import CartesianDomain
6+
from ...operator import grad, laplacian
7+
from ...equation import Equation, FixedValue
8+
from ...problem import SpatialProblem, TimeDependentProblem
99

1010

1111
def diffusion_reaction(input_, output_):
1212
"""
1313
Implementation of the diffusion-reaction equation.
14+
15+
:param LabelTensor input_: Input data of the problem.
16+
:param LabelTensor output_: Output data of the problem.
17+
:return: The residual of the diffusion-reaction equation.
18+
:rtype: LabelTensor
1419
"""
1520
x = input_.extract("x")
1621
t = input_.extract("t")
17-
u_t = grad(output_, input_, d="t")
18-
u_x = grad(output_, input_, d="x")
19-
u_xx = grad(u_x, input_, d="x")
22+
u_t = grad(output_, input_, components=["u"], d=["t"])
23+
u_xx = laplacian(output_, input_, components=["u"], d=["x"])
2024
r = torch.exp(-t) * (
2125
1.5 * torch.sin(2 * x)
2226
+ (8 / 3) * torch.sin(3 * x)
@@ -26,30 +30,72 @@ def diffusion_reaction(input_, output_):
2630
return u_t - u_xx - r
2731

2832

29-
class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
33+
def initial_condition(input_, output_):
34+
"""
35+
Definition of the initial condition of the diffusion-reaction problem.
36+
37+
:param LabelTensor input_: Input data of the problem.
38+
:param LabelTensor output_: Output data of the problem.
39+
:return: The residual of the initial condition.
40+
:rtype: LabelTensor
3041
"""
31-
Implementation of the diffusion-reaction problem on the spatial interval
32-
[-pi, pi] and temporal interval [0,1].
42+
x = input_.extract("x")
43+
u_0 = (
44+
torch.sin(x)
45+
+ (1 / 2) * torch.sin(2 * x)
46+
+ (1 / 3) * torch.sin(3 * x)
47+
+ (1 / 4) * torch.sin(4 * x)
48+
+ (1 / 8) * torch.sin(8 * x)
49+
)
50+
return output_ - u_0
51+
52+
53+
class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
54+
r"""
55+
Implementation of the diffusion-reaction problem in the spatial interval
56+
:math:`[-\pi, \pi]` and temporal interval :math:`[0, 1]`.
57+
58+
.. seealso::
59+
**Original reference**: Si, Chenhao, et al. *Complex Physics-Informed
60+
Neural Network.* arXiv preprint arXiv:2502.04917 (2025).
61+
DOI: `arXiv:2502.04917 <https://arxiv.org/abs/2502.04917>`_.
3362
"""
3463

3564
output_variables = ["u"]
3665
spatial_domain = CartesianDomain({"x": [-torch.pi, torch.pi]})
3766
temporal_domain = CartesianDomain({"t": [0, 1]})
3867

68+
domains = {
69+
"D": CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
70+
"g1": CartesianDomain({"x": -torch.pi, "t": [0, 1]}),
71+
"g2": CartesianDomain({"x": torch.pi, "t": [0, 1]}),
72+
"t0": CartesianDomain({"x": [-torch.pi, torch.pi], "t": 0.0}),
73+
}
74+
3975
conditions = {
40-
"D": Condition(
41-
domain=CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
42-
equation=Equation(diffusion_reaction),
43-
)
76+
"D": Condition(domain="D", equation=Equation(diffusion_reaction)),
77+
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
78+
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
79+
"t0": Condition(domain="t0", equation=Equation(initial_condition)),
4480
}
4581

46-
def _solution(self, pts):
82+
def solution(self, pts):
83+
"""
84+
Implementation of the analytical solution of the diffusion-reaction
85+
problem.
86+
87+
:param LabelTensor pts: Points where the solution is evaluated.
88+
:return: The analytical solution of the diffusion-reaction problem.
89+
:rtype: LabelTensor
90+
"""
4791
t = pts.extract("t")
4892
x = pts.extract("x")
49-
return torch.exp(-t) * (
93+
sol = torch.exp(-t) * (
5094
torch.sin(x)
5195
+ (1 / 2) * torch.sin(2 * x)
5296
+ (1 / 3) * torch.sin(3 * x)
5397
+ (1 / 4) * torch.sin(4 * x)
5498
+ (1 / 8) * torch.sin(8 * x)
5599
)
100+
sol.labels = self.output_variables
101+
return sol

0 commit comments

Comments
 (0)