Skip to content

Commit 7469543

Browse files
dario-cosciaGiovanniCanali
authored andcommitted
simplify kwargs logic for equations
1 parent 684d691 commit 7469543

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

pina/equation/equation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Module for the Equation."""
22

3+
import inspect
4+
35
from .equation_interface import EquationInterface
46

57

@@ -25,6 +27,9 @@ def __init__(self, equation):
2527
"Expected a callable function, got "
2628
f"{equation}"
2729
)
30+
# compute the signature
31+
sig = inspect.signature(equation)
32+
self.__len_sig = len(sig.parameters)
2833
self.__equation = equation
2934

3035
def residual(self, input_, output_, params_=None):
@@ -41,9 +46,14 @@ def residual(self, input_, output_, params_=None):
4146
parameters must be initialized to ``None``. Default is ``None``.
4247
:return: The computed residual of the equation.
4348
:rtype: LabelTensor
49+
:raises RuntimeError: If the underlying equation signature length is not
50+
2 (direct problem) or 3 (inverse problem).
4451
"""
45-
if params_ is None:
46-
result = self.__equation(input_, output_)
47-
else:
48-
result = self.__equation(input_, output_, params_)
49-
return result
52+
if self.__len_sig == 2:
53+
return self.__equation(input_, output_)
54+
if self.__len_sig == 3:
55+
return self.__equation(input_, output_, params_)
56+
raise RuntimeError(
57+
f"Unexpected number of arguments in equation: {self.__len_sig}. "
58+
"Expected either 2 (direct problem) or 3 (inverse problem)."
59+
)

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,9 @@ def compute_residual(self, samples, equation):
190190
:return: The residual of the solution of the model.
191191
:rtype: LabelTensor
192192
"""
193-
try:
194-
residual = equation.residual(samples, self.forward(samples))
195-
except TypeError:
196-
# this occurs when the function has three inputs (inverse problem)
197-
residual = equation.residual(
198-
samples, self.forward(samples), self._params
199-
)
193+
residual = equation.residual(
194+
samples, self.forward(samples), self._params
195+
)
200196
return residual
201197

202198
def _residual_loss(self, samples, equation):

0 commit comments

Comments
 (0)