Skip to content

Commit 9be5a27

Browse files
update problem zoo and tests post domain refactory
1 parent acef631 commit 9be5a27

File tree

13 files changed

+40
-65
lines changed

13 files changed

+40
-65
lines changed

pina/problem/abstract_problem.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ def collect_data(self):
328328
# Only store the discretisation points if the domain is
329329
# in the dictionary
330330
if condition.domain in self.discretised_domains:
331-
samples = self.discretised_domains[condition.domain]
331+
samples = self.discretised_domains[condition.domain][
332+
self.input_variables
333+
]
332334
data[condition_name] = {
333335
"input": samples,
334336
"equation": condition.equation,

pina/problem/zoo/acoustic_wave.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,17 @@ class AcousticWaveProblem(TimeDependentProblem, SpatialProblem):
5050
temporal_domain = CartesianDomain({"t": [0, 1]})
5151

5252
domains = {
53-
"D": CartesianDomain({"x": [0, 1], "t": [0, 1]}),
54-
"t0": CartesianDomain({"x": [0, 1], "t": 0.0}),
55-
"g1": CartesianDomain({"x": 0.0, "t": [0, 1]}),
56-
"g2": CartesianDomain({"x": 1.0, "t": [0, 1]}),
53+
"D": spatial_domain.update(temporal_domain),
54+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
55+
"boundary": spatial_domain.partial().update(temporal_domain),
5756
}
5857

5958
conditions = {
60-
"g1": Condition(domain="g1", equation=FixedValue(value=0.0)),
61-
"g2": Condition(domain="g2", equation=FixedValue(value=0.0)),
59+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
6260
"t0": Condition(
6361
domain="t0",
6462
equation=SystemEquation(
65-
[Equation(initial_condition), FixedGradient(value=0.0, d="t")]
63+
[Equation(initial_condition), FixedGradient(0.0, d="t")]
6664
),
6765
),
6866
}

pina/problem/zoo/advection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class AdvectionProblem(SpatialProblem, TimeDependentProblem):
4242
temporal_domain = CartesianDomain({"t": [0, 1]})
4343

4444
domains = {
45-
"D": CartesianDomain({"x": [0, 2 * torch.pi], "t": [0, 1]}),
46-
"t0": CartesianDomain({"x": [0, 2 * torch.pi], "t": 0.0}),
45+
"D": spatial_domain.update(temporal_domain),
46+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
4747
}
4848

4949
conditions = {

pina/problem/zoo/allen_cahn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class AllenCahnProblem(TimeDependentProblem, SpatialProblem):
4747
temporal_domain = CartesianDomain({"t": [0, 1]})
4848

4949
domains = {
50-
"D": CartesianDomain({"x": [-1, 1], "t": [0, 1]}),
51-
"t0": CartesianDomain({"x": [-1, 1], "t": 0.0}),
50+
"D": spatial_domain.update(temporal_domain),
51+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
5252
}
5353

5454
conditions = {

pina/problem/zoo/diffusion_reaction.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,13 @@ class DiffusionReactionProblem(TimeDependentProblem, SpatialProblem):
4949
temporal_domain = CartesianDomain({"t": [0, 1]})
5050

5151
domains = {
52-
"D": CartesianDomain({"x": [-torch.pi, torch.pi], "t": [0, 1]}),
53-
"g1": CartesianDomain({"x": -torch.pi, "t": [0, 1]}),
54-
"g2": CartesianDomain({"x": torch.pi, "t": [0, 1]}),
55-
"t0": CartesianDomain({"x": [-torch.pi, torch.pi], "t": 0.0}),
52+
"D": spatial_domain.update(temporal_domain),
53+
"boundary": spatial_domain.partial().update(temporal_domain),
54+
"t0": spatial_domain.update(CartesianDomain({"t": 0})),
5655
}
5756

5857
conditions = {
59-
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
60-
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
58+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
6159
"t0": Condition(domain="t0", equation=Equation(initial_condition)),
6260
}
6361

pina/problem/zoo/helmholtz.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,12 @@ class HelmholtzProblem(SpatialProblem):
2828
spatial_domain = CartesianDomain({"x": [-1, 1], "y": [-1, 1]})
2929

3030
domains = {
31-
"D": CartesianDomain({"x": [-1, 1], "y": [-1, 1]}),
32-
"g1": CartesianDomain({"x": [-1, 1], "y": 1.0}),
33-
"g2": CartesianDomain({"x": [-1, 1], "y": -1.0}),
34-
"g3": CartesianDomain({"x": 1.0, "y": [-1, 1]}),
35-
"g4": CartesianDomain({"x": -1.0, "y": [-1, 1]}),
31+
"D": spatial_domain,
32+
"boundary": spatial_domain.partial(),
3633
}
3734

3835
conditions = {
39-
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
40-
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
41-
"g3": Condition(domain="g3", equation=FixedValue(0.0)),
42-
"g4": Condition(domain="g4", equation=FixedValue(0.0)),
36+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
4337
}
4438

4539
def __init__(self, alpha=3.0):

pina/problem/zoo/inverse_poisson_2d_square.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,13 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
9090
unknown_parameter_domain = CartesianDomain({"mu1": [-1, 1], "mu2": [-1, 1]})
9191

9292
domains = {
93-
"g1": CartesianDomain({"x": [x_min, x_max], "y": y_max}),
94-
"g2": CartesianDomain({"x": [x_min, x_max], "y": y_min}),
95-
"g3": CartesianDomain({"x": x_max, "y": [y_min, y_max]}),
96-
"g4": CartesianDomain({"x": x_min, "y": [y_min, y_max]}),
97-
"D": CartesianDomain({"x": [x_min, x_max], "y": [y_min, y_max]}),
93+
"D": spatial_domain,
94+
"boundary": spatial_domain.partial(),
9895
}
9996

10097
conditions = {
101-
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
102-
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
103-
"g3": Condition(domain="g3", equation=FixedValue(0.0)),
104-
"g4": Condition(domain="g4", equation=FixedValue(0.0)),
10598
"D": Condition(domain="D", equation=Equation(laplace_equation)),
99+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
106100
}
107101

108102
def __init__(self, load=True, data_size=1.0):

pina/problem/zoo/poisson_2d_square.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,12 @@ class Poisson2DSquareProblem(SpatialProblem):
3636
spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
3737

3838
domains = {
39-
"D": CartesianDomain({"x": [0, 1], "y": [0, 1]}),
40-
"g1": CartesianDomain({"x": [0, 1], "y": 1.0}),
41-
"g2": CartesianDomain({"x": [0, 1], "y": 0.0}),
42-
"g3": CartesianDomain({"x": 1.0, "y": [0, 1]}),
43-
"g4": CartesianDomain({"x": 0.0, "y": [0, 1]}),
39+
"D": spatial_domain,
40+
"boundary": spatial_domain.partial(),
4441
}
4542

4643
conditions = {
47-
"g1": Condition(domain="g1", equation=FixedValue(0.0)),
48-
"g2": Condition(domain="g2", equation=FixedValue(0.0)),
49-
"g3": Condition(domain="g3", equation=FixedValue(0.0)),
50-
"g4": Condition(domain="g4", equation=FixedValue(0.0)),
44+
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
5145
"D": Condition(domain="D", equation=Poisson(forcing_term=forcing_term)),
5246
}
5347

tests/test_callback/test_metric_tracker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77

88
# make the problem
99
poisson_problem = Poisson()
10-
boundaries = ["g1", "g2", "g3", "g4"]
1110
n = 10
12-
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
11+
poisson_problem.discretise_domain(n, "grid", domains="boundary")
1312
poisson_problem.discretise_domain(n, "grid", domains="D")
1413
model = FeedForward(
1514
len(poisson_problem.input_variables), len(poisson_problem.output_variables)

tests/test_callback/test_pina_progress_bar.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
# make the problem
99
poisson_problem = Poisson()
10-
boundaries = ["g1", "g2", "g3", "g4"]
1110
n = 10
1211
condition_names = list(poisson_problem.conditions.keys())
13-
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
12+
poisson_problem.discretise_domain(n, "grid", domains="boundary")
1413
poisson_problem.discretise_domain(n, "grid", domains="D")
1514
model = FeedForward(
1615
len(poisson_problem.input_variables), len(poisson_problem.output_variables)

0 commit comments

Comments
 (0)