diff --git a/pina/equation/equation_factory.py b/pina/equation/equation_factory.py index 057ea65d4..da9c55647 100644 --- a/pina/equation/equation_factory.py +++ b/pina/equation/equation_factory.py @@ -239,19 +239,19 @@ def equation(input_, output_): ) # Ensure consistency of c length - if len(self.c) != (len(input_lbl) - 1) and len(self.c) > 1: + if self.c.shape[-1] != len(input_lbl) - 1 and self.c.shape[-1] > 1: raise ValueError( "If 'c' is passed as a list, its length must be equal to " "the number of spatial dimensions." ) # Repeat c to ensure consistent shape for advection - self.c = self.c.repeat(output_.shape[0], 1) - if self.c.shape[1] != (len(input_lbl) - 1): - self.c = self.c.repeat(1, len(input_lbl) - 1) + c = self.c.repeat(output_.shape[0], 1) + if c.shape[1] != (len(input_lbl) - 1): + c = c.repeat(1, len(input_lbl) - 1) # Add a dimension to c for the following operations - self.c = self.c.unsqueeze(-1) + c = c.unsqueeze(-1) # Compute the time derivative and the spatial gradient time_der = grad(output_, input_, components=None, d="t") @@ -262,7 +262,7 @@ def equation(input_, output_): tmp = tmp.transpose(-1, -2) # Compute advection term - adv = (tmp * self.c).sum(dim=tmp.tensor.ndim - 2) + adv = (tmp * c).sum(dim=tmp.tensor.ndim - 2) return time_der + adv diff --git a/pina/problem/zoo/helmholtz.py b/pina/problem/zoo/helmholtz.py index 5f3f956af..0f38780c7 100644 --- a/pina/problem/zoo/helmholtz.py +++ b/pina/problem/zoo/helmholtz.py @@ -48,11 +48,10 @@ def __init__(self, alpha=3.0): :type alpha: float | int """ super().__init__() - - self.alpha = alpha check_consistency(alpha, (int, float)) + self.alpha = alpha - def forcing_term(self, input_): + def forcing_term(input_): """ Implementation of the forcing term. """ diff --git a/tests/test_equation/test_equation_factory.py b/tests/test_equation/test_equation_factory.py index 4a9875115..be01427cb 100644 --- a/tests/test_equation/test_equation_factory.py +++ b/tests/test_equation/test_equation_factory.py @@ -104,7 +104,7 @@ def test_advection_equation(c): # Should fail if c is a list and its length != spatial dimension with pytest.raises(ValueError): - Advection([1, 2, 3]) + equation = Advection([1, 2, 3]) residual = equation.residual(pts, u)