Skip to content

Commit a850576

Browse files
fix bug + add training tests
1 parent 256ac9d commit a850576

File tree

9 files changed

+111
-13
lines changed

9 files changed

+111
-13
lines changed

pina/equation/equation_factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,12 @@ def equation(input_, output_):
246246
)
247247

248248
# Repeat c to ensure consistent shape for advection
249-
self.c = self.c.repeat(output_.shape[0], 1)
250-
if self.c.shape[1] != (len(input_lbl) - 1):
251-
self.c = self.c.repeat(1, len(input_lbl) - 1)
249+
c = self.c.repeat(output_.shape[0], 1).to(input_.device)
250+
if c.shape[1] != (len(input_lbl) - 1):
251+
c = c.repeat(1, len(input_lbl) - 1)
252252

253253
# Add a dimension to c for the following operations
254-
self.c = self.c.unsqueeze(-1)
254+
c = c.unsqueeze(-1)
255255

256256
# Compute the time derivative and the spatial gradient
257257
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ def equation(input_, output_):
262262
tmp = tmp.transpose(-1, -2)
263263

264264
# Compute advection term
265-
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
265+
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
266266

267267
return time_der + adv
268268

pina/problem/zoo/helmholtz.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ def __init__(self, alpha=3.0):
4848
:type alpha: float | int
4949
"""
5050
super().__init__()
51-
52-
self.alpha = alpha
5351
check_consistency(alpha, (int, float))
52+
self.alpha = alpha
5453

55-
def forcing_term(self, input_):
54+
def forcing_term(input_):
5655
"""
5756
Implementation of the forcing term.
5857
"""

tests/test_problem_zoo/test_advection.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import pytest
2-
from pina.problem.zoo import AdvectionProblem
32
from pina.problem import SpatialProblem, TimeDependentProblem
3+
from pina.problem.zoo import AdvectionProblem
4+
from pina.model import FeedForward
5+
from pina.solver import PINN
6+
from pina import Trainer
47

58

69
@pytest.mark.parametrize("c", [1.5, 3])
@@ -17,3 +20,14 @@ def test_constructor(c):
1720
# Should fail if c is not a float or int
1821
with pytest.raises(ValueError):
1922
AdvectionProblem(c="invalid")
23+
24+
25+
@pytest.mark.parametrize("c", [1.5, 3])
26+
def test_train(c):
27+
28+
problem = AdvectionProblem(c=c)
29+
problem.discretise_domain(n=10)
30+
model = FeedForward(2, 1, 10, 2)
31+
solver = PINN(problem, model)
32+
trainer = Trainer(solver, max_epochs=2)
33+
trainer.train()

tests/test_problem_zoo/test_allen_cahn.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import pytest
2-
from pina.problem.zoo import AllenCahnProblem
32
from pina.problem import SpatialProblem, TimeDependentProblem
3+
from pina.problem.zoo import AllenCahnProblem
4+
from pina.problem import SpatialProblem
5+
from pina.model import FeedForward
6+
from pina.solver import PINN
7+
from pina import Trainer
48

59

610
@pytest.mark.parametrize("alpha", [0.1, 1])
@@ -22,3 +26,15 @@ def test_constructor(alpha, beta):
2226
# Should fail if beta is not a float or int
2327
with pytest.raises(ValueError):
2428
AllenCahnProblem(alpha=alpha, beta="invalid")
29+
30+
31+
@pytest.mark.parametrize("alpha", [0.1, 1])
32+
@pytest.mark.parametrize("beta", [0.1, 1])
33+
def test_train(alpha, beta):
34+
35+
problem = AllenCahnProblem(alpha=alpha, beta=beta)
36+
problem.discretise_domain(n=10)
37+
model = FeedForward(2, 1, 10, 2)
38+
solver = PINN(problem, model)
39+
trainer = Trainer(solver, max_epochs=2)
40+
trainer.train()

tests/test_problem_zoo/test_diffusion_reaction.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import pytest
2-
from pina.problem.zoo import DiffusionReactionProblem
32
from pina.problem import TimeDependentProblem, SpatialProblem
3+
from pina.problem.zoo import DiffusionReactionProblem
4+
from pina.problem import SpatialProblem
5+
from pina.model import FeedForward
6+
from pina.solver import PINN
7+
from pina import Trainer
48

59

610
@pytest.mark.parametrize("alpha", [0.1, 1])
@@ -17,3 +21,14 @@ def test_constructor(alpha):
1721
# Should fail if alpha is not a float or int
1822
with pytest.raises(ValueError):
1923
problem = DiffusionReactionProblem(alpha="invalid")
24+
25+
26+
@pytest.mark.parametrize("alpha", [0.1, 1])
27+
def test_train(alpha):
28+
29+
problem = DiffusionReactionProblem(alpha=alpha)
30+
problem.discretise_domain(n=10)
31+
model = FeedForward(2, 1, 10, 2)
32+
solver = PINN(problem, model)
33+
trainer = Trainer(solver, max_epochs=2)
34+
trainer.train()

tests/test_problem_zoo/test_helmholtz.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import pytest
22
from pina.problem.zoo import HelmholtzProblem
33
from pina.problem import SpatialProblem
4+
from pina.model import FeedForward
5+
from pina.solver import PINN
6+
from pina import Trainer
47

58

69
@pytest.mark.parametrize("alpha", [1.5, 3])
@@ -15,3 +18,14 @@ def test_constructor(alpha):
1518

1619
with pytest.raises(ValueError):
1720
HelmholtzProblem(alpha="invalid")
21+
22+
23+
@pytest.mark.parametrize("alpha", [1.5, 3])
24+
def test_train(alpha):
25+
26+
problem = HelmholtzProblem(alpha=alpha)
27+
problem.discretise_domain(n=10)
28+
model = FeedForward(2, 1, 10, 2)
29+
solver = PINN(problem, model)
30+
trainer = Trainer(solver, max_epochs=2, accelerator="cpu")
31+
trainer.train()

tests/test_problem_zoo/test_inverse_poisson_2d_square.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import pytest
22
from pina.problem.zoo import InversePoisson2DSquareProblem
33
from pina.problem import InverseProblem, SpatialProblem
4+
from pina.model import FeedForward
5+
from pina.solver import PINN
6+
from pina import Trainer
47

58

69
@pytest.mark.parametrize("load", [True, False])
@@ -23,3 +26,13 @@ def test_constructor(load, data_size):
2326
# Should fail if data_size is not in the range [0.0, 1.0]
2427
with pytest.raises(ValueError):
2528
problem = InversePoisson2DSquareProblem(load=load, data_size=3.0)
29+
30+
31+
def test_train():
32+
33+
problem = InversePoisson2DSquareProblem()
34+
problem.discretise_domain(n=10)
35+
model = FeedForward(2, 1, 10, 2)
36+
solver = PINN(problem, model)
37+
trainer = Trainer(solver, max_epochs=2, accelerator="gpu")
38+
trainer.train()

tests/test_problem_zoo/test_poisson_2d_square.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from pina.problem.zoo import Poisson2DSquareProblem
22
from pina.problem import SpatialProblem
3+
from pina.model import FeedForward
4+
from pina.solver import PINN
5+
from pina import Trainer
36

47

58
def test_constructor():
@@ -10,3 +13,13 @@ def test_constructor():
1013
assert isinstance(problem, SpatialProblem)
1114
assert hasattr(problem, "conditions")
1215
assert isinstance(problem.conditions, dict)
16+
17+
18+
def test_train():
19+
20+
problem = Poisson2DSquareProblem()
21+
problem.discretise_domain(n=10)
22+
model = FeedForward(2, 1, 10, 2)
23+
solver = PINN(problem, model)
24+
trainer = Trainer(solver, max_epochs=2)
25+
trainer.train()

tests/test_problem_zoo/test_supervised_problem.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import torch
2-
from pina.problem import AbstractProblem
3-
from pina.condition import InputTargetCondition
42
from pina.problem.zoo.supervised_problem import SupervisedProblem
3+
from pina.condition import InputTargetCondition
4+
from pina.problem import AbstractProblem
5+
from pina.solver import SupervisedSolver
56
from pina.graph import RadiusGraph
7+
from pina.model import FeedForward
8+
from pina import Trainer
69

710

811
def test_constructor():
@@ -32,3 +35,14 @@ def test_constructor_graph():
3235
assert isinstance(problem.conditions["data"], InputTargetCondition)
3336
assert isinstance(problem.conditions["data"].input, list)
3437
assert isinstance(problem.conditions["data"].target, torch.Tensor)
38+
39+
40+
def test_train():
41+
42+
input_ = torch.rand((100, 10))
43+
output_ = torch.rand((100, 10))
44+
problem = SupervisedProblem(input_=input_, output_=output_)
45+
model = FeedForward(10, 10, 10, 2)
46+
solver = SupervisedSolver(problem, model, use_lt=False)
47+
trainer = Trainer(solver, max_epochs=2, accelerator="cpu")
48+
trainer.train()

0 commit comments

Comments
 (0)