Skip to content

Commit 98d9491

Browse files
fix bugs for helmholtz and advection
1 parent 31f91fc commit 98d9491

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

pina/equation/equation_factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,12 @@ def equation(input_, output_):
246246
)
247247

248248
# Repeat c to ensure consistent shape for advection
249-
self.c = self.c.repeat(output_.shape[0], 1)
250-
if self.c.shape[1] != (len(input_lbl) - 1):
251-
self.c = self.c.repeat(1, len(input_lbl) - 1)
249+
c = self.c.repeat(output_.shape[0], 1)
250+
if c.shape[1] != (len(input_lbl) - 1):
251+
c = c.repeat(1, len(input_lbl) - 1)
252252

253253
# Add a dimension to c for the following operations
254-
self.c = self.c.unsqueeze(-1)
254+
c = c.unsqueeze(-1)
255255

256256
# Compute the time derivative and the spatial gradient
257257
time_der = grad(output_, input_, components=None, d="t")
@@ -262,7 +262,7 @@ def equation(input_, output_):
262262
tmp = tmp.transpose(-1, -2)
263263

264264
# Compute advection term
265-
adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2)
265+
adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2)
266266

267267
return time_der + adv
268268

pina/problem/zoo/helmholtz.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@ def __init__(self, alpha=3.0):
4848
:type alpha: float | int
4949
"""
5050
super().__init__()
51-
52-
self.alpha = alpha
5351
check_consistency(alpha, (int, float))
52+
self.alpha = alpha
5453

55-
def forcing_term(self, input_):
54+
def forcing_term(input_):
5655
"""
5756
Implementation of the forcing term.
5857
"""

0 commit comments

Comments
 (0)