Skip to content

Commit 2647a18

Browse files
authored
feat: allow additional inputs to be passed to IVP (#225)
1 parent 720e23e commit 2647a18

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

neurodiffeq/conditions.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,23 @@ class IVP(BaseCondition):
230230
231231
:param t_0: The initial time.
232232
:type t_0: float
233-
:param u_0: The initial value of :math:`u`. :math:`u(t_0)=u_0`.
234-
:type u_0: float
233+
:param u_0: If a float, this is the initial value of :math:`u`. :math:`u(t_0)=u_0`. If a callable, this is the
234+
:math:`u(t_0, x, y, ...)=u_0(x, y, ...)` function that takes additional tensors as inputs.
235+
:type u_0: float or callable.
235236
:param u_0_prime:
236237
The initial derivative of :math:`u` w.r.t. :math:`t`.
237238
:math:`\displaystyle\frac{\partial u}{\partial t}\bigg|_{t = t_0} = u_0'`.
239+
Similar to `u_0` this can be a function of additional tensors, i.e., `u_0_prime(x, y, ...)`.
238240
Defaults to None.
239-
:type u_0_prime: float, optional
241+
:type u_0_prime: float or callable, optional
240242
"""
241243

242244
@deprecated_alias(x_0='u_0', x_0_prime='u_0_prime')
243245
def __init__(self, t_0, u_0=None, u_0_prime=None):
244246
super().__init__()
245247
self.t_0, self.u_0, self.u_0_prime = t_0, u_0, u_0_prime
246248

247-
def parameterize(self, output_tensor, t):
249+
def parameterize(self, output_tensor, t, *additional_tensors):
248250
r"""Re-parameterizes outputs such that the Dirichlet/Neumann condition is satisfied.
249251
250252
- For Dirichlet condition, the re-parameterization is
@@ -258,13 +260,18 @@ def parameterize(self, output_tensor, t):
258260
:type output_tensor: `torch.Tensor`
259261
:param t: Input to the neural network; i.e., sampled time-points; i.e., independent variables.
260262
:type t: `torch.Tensor`
263+
:param additional_tensors: Additional inputs to the neural network. If u_0 or u_0_prime are `callable`s,
264+
the additional tensors will be passed to u_0 and u_0_prime.
265+
:type additional_tensors: `torch.Tensor`
261266
:return: The re-parameterized output of the network.
262267
:rtype: `torch.Tensor`
263268
"""
269+
u_0 = self.u_0(*additional_tensors) if callable(self.u_0) else self.u_0
264270
if self.u_0_prime is None:
265-
return self.u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
271+
return u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
266272
else:
267-
return self.u_0 + (t - self.t_0) * self.u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor
273+
u_0_prime = self.u_0_prime(*additional_tensors) if callable(self.u_0_prime) else self.u_0_prime
274+
return u_0 + (t - self.t_0) * u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor
268275

269276

270277
class BundleIVP(BaseCondition, _BundleConditionMixin):

0 commit comments

Comments
 (0)