@@ -230,21 +230,23 @@ class IVP(BaseCondition):
230230
231231 :param t_0: The initial time.
232232 :type t_0: float
233- :param u_0: The initial value of :math:`u`. :math:`u(t_0)=u_0`.
234- :type u_0: float
233+ :param u_0: If a float, this is the initial value of :math:`u`. :math:`u(t_0)=u_0`. If a callable, this is the
234+ :math:`u(t_0, x, y, ...)=u_0(x, y, ...)` function that takes additional tensors as inputs.
235+ :type u_0: float or callable.
235236 :param u_0_prime:
236237 The initial derivative of :math:`u` w.r.t. :math:`t`.
237238 :math:`\displaystyle\frac{\partial u}{\partial t}\bigg|_{t = t_0} = u_0'`.
239+ Similar to `u_0` this can be a function of additional tensors, i.e., `u_0_prime(x, y, ...)`.
238240 Defaults to None.
239- :type u_0_prime: float, optional
241+ :type u_0_prime: float or callable , optional
240242 """
241243
242244 @deprecated_alias (x_0 = 'u_0' , x_0_prime = 'u_0_prime' )
243245 def __init__ (self , t_0 , u_0 = None , u_0_prime = None ):
244246 super ().__init__ ()
245247 self .t_0 , self .u_0 , self .u_0_prime = t_0 , u_0 , u_0_prime
246248
247- def parameterize (self , output_tensor , t ):
249+ def parameterize (self , output_tensor , t , * additional_tensors ):
248250 r"""Re-parameterizes outputs such that the Dirichlet/Neumann condition is satisfied.
249251
250252 - For Dirichlet condition, the re-parameterization is
@@ -258,13 +260,18 @@ def parameterize(self, output_tensor, t):
258260 :type output_tensor: `torch.Tensor`
259261 :param t: Input to the neural network; i.e., sampled time-points; i.e., independent variables.
260262 :type t: `torch.Tensor`
263+ :param additional_tensors: Additional inputs to the neural network. If u_0 or u_0_prime are `callable`s,
264+ the additional tensors will be passed to u_0 and u_0_prime.
265+ :type additional_tensors: `torch.Tensor`
261266 :return: The re-parameterized output of the network.
262267 :rtype: `torch.Tensor`
263268 """
269+ u_0 = self .u_0 (* additional_tensors ) if callable (self .u_0 ) else self .u_0
264270 if self .u_0_prime is None :
265- return self . u_0 + (1 - torch .exp (- t + self .t_0 )) * output_tensor
271+ return u_0 + (1 - torch .exp (- t + self .t_0 )) * output_tensor
266272 else :
267- return self .u_0 + (t - self .t_0 ) * self .u_0_prime + ((1 - torch .exp (- t + self .t_0 )) ** 2 ) * output_tensor
273+ u_0_prime = self .u_0_prime (* additional_tensors ) if callable (self .u_0_prime ) else self .u_0_prime
274+ return u_0 + (t - self .t_0 ) * u_0_prime + ((1 - torch .exp (- t + self .t_0 )) ** 2 ) * output_tensor
268275
269276
270277class BundleIVP (BaseCondition , _BundleConditionMixin ):
0 commit comments