@@ -27,12 +27,14 @@ def grad(output_, input_, components=None, d=None):
2727 computed.
2828 :param LabelTensor input_: The input tensor with respect to which the
2929 gradient is computed.
30- :param list[str] components: The names of the output variables for which to
30+ :param components: The names of the output variables for which to
3131 compute the gradient. It must be a subset of the output labels.
3232 If ``None``, all output variables are considered. Default is ``None``.
33- :param list[str] d: The names of the input variables with respect to which
33+ :type components: str | list[str]
34+ :param d: The names of the input variables with respect to which
3435 the gradient is computed. It must be a subset of the input labels.
3536 If ``None``, all input variables are considered. Default is ``None``.
37+ :type d: str | list[str]
3638 :raises TypeError: If the input tensor is not a LabelTensor.
3739 :raises RuntimeError: If the output is a scalar field and the components
3840 are not equal to the output labels.
@@ -50,9 +52,10 @@ def grad_scalar_output(output_, input_, d):
5052 computed. It must be a column tensor.
5153 :param LabelTensor input_: The input tensor with respect to which the
5254 gradient is computed.
53- :param list[str] d: The names of the input variables with respect to
55+ :param d: The names of the input variables with respect to
5456 which the gradient is computed. It must be a subset of the input
5557 labels. If ``None``, all input variables are considered.
58+ :type d: str | list[str]
5659 :raises RuntimeError: If a vectorial function is passed.
5760 :raises RuntimeError: If missing derivative labels.
5861 :return: The computed gradient tensor.
@@ -89,6 +92,12 @@ def grad_scalar_output(output_, input_, d):
8992 if components is None :
9093 components = output_ .labels
9194
95+ if not isinstance (components , list ):
96+ components = [components ]
97+
98+ if not isinstance (d , list ):
99+ d = [d ]
100+
92101 if output_ .shape [1 ] == 1 : # scalar output ################################
93102
94103 if components != output_ .labels :
@@ -120,12 +129,14 @@ def div(output_, input_, components=None, d=None):
120129 computed.
121130 :param LabelTensor input_: The input tensor with respect to which the
122131 divergence is computed.
123- :param list[str] components: The names of the output variables for which to
132+ :param components: The names of the output variables for which to
124133 compute the divergence. It must be a subset of the output labels.
125134 If ``None``, all output variables are considered. Default is ``None``.
126- :param list[str] d: The names of the input variables with respect to which
135+ :type components: str | list[str]
136+ :param d: The names of the input variables with respect to which
127137 the divergence is computed. It must be a subset of the input labels.
128138 If ``None``, all input variables are considered. Default is ``None``.
139+ :type d: str | list[str]
129140 :raises TypeError: If the input tensor is not a LabelTensor.
130141 :raises ValueError: If the output is a scalar field.
131142 :raises ValueError: If the number of components is not equal to the number
@@ -142,6 +153,12 @@ def div(output_, input_, components=None, d=None):
142153 if components is None :
143154 components = output_ .labels
144155
156+ if not isinstance (components , list ):
157+ components = [components ]
158+
159+ if not isinstance (d , list ):
160+ d = [d ]
161+
145162 if output_ .shape [1 ] < 2 or len (components ) < 2 :
146163 raise ValueError ("div supported only for vector fields" )
147164
@@ -170,12 +187,14 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
170187 computed.
171188 :param LabelTensor input_: The input tensor with respect to which the
172189 laplacian is computed.
173- :param list[str] components: The names of the output variables for which to
190+ :param components: The names of the output variables for which to
174191 compute the laplacian. It must be a subset of the output labels.
175192 If ``None``, all output variables are considered. Default is ``None``.
176- :param list[str] d: The names of the input variables with respect to which
193+ :type components: str | list[str]
194+ :param d: The names of the input variables with respect to which
177195 the laplacian is computed. It must be a subset of the input labels.
178196 If ``None``, all input variables are considered. Default is ``None``.
197+ :type d: str | list[str]
179198 :param str method: The method used to compute the Laplacian. Default is
180199 ``std``.
181200 :raises NotImplementedError: If ``std=divgrad``.
@@ -191,12 +210,14 @@ def scalar_laplace(output_, input_, components, d):
191210 computed. It must be a column tensor.
192211 :param LabelTensor input_: The input tensor with respect to which the
193212 laplacian is computed.
194- :param list[str] components: The names of the output variables for which
213+ :param components: The names of the output variables for which
195214 to compute the laplacian. It must be a subset of the output labels.
196215 If ``None``, all output variables are considered.
197- :param list[str] d: The names of the input variables with respect to
216+ :type components: str | list[str]
217+ :param d: The names of the input variables with respect to
198218 which the laplacian is computed. It must be a subset of the input
199219 labels. If ``None``, all input variables are considered.
220+ :type d: str | list[str]
200221 :return: The computed laplacian tensor.
201222 :rtype: LabelTensor
202223 """
@@ -216,22 +237,24 @@ def scalar_laplace(output_, input_, components, d):
216237 if components is None :
217238 components = output_ .labels
218239
240+ if not isinstance (components , list ):
241+ components = [components ]
242+
243+ if not isinstance (d , list ):
244+ d = [d ]
245+
219246 if method == "divgrad" :
220247 raise NotImplementedError ("divgrad not implemented as method" )
221248
222249 if method == "std" :
223- if len (components ) == 1 :
224- result = scalar_laplace (output_ , input_ , components , d )
225- labels = [f"dd{ components [0 ]} " ]
226-
227- else :
228- result = torch .empty (
229- input_ .shape [0 ], len (components ), device = output_ .device
230- )
231- labels = [None ] * len (components )
232- for idx , c in enumerate (components ):
233- result [:, idx ] = scalar_laplace (output_ , input_ , c , d ).flatten ()
234- labels [idx ] = f"dd{ c } "
250+
251+ result = torch .empty (
252+ input_ .shape [0 ], len (components ), device = output_ .device
253+ )
254+ labels = [None ] * len (components )
255+ for idx , c in enumerate (components ):
256+ result [:, idx ] = scalar_laplace (output_ , input_ , [c ], d ).flatten ()
257+ labels [idx ] = f"dd{ c } "
235258
236259 result = result .as_subclass (LabelTensor )
237260 result .labels = labels
@@ -251,12 +274,14 @@ def advection(output_, input_, velocity_field, components=None, d=None):
251274 is computed.
252275 :param str velocity_field: The name of the output variable used as velocity
253276 field. It must be chosen among the output labels.
254- :param list[str] components: The names of the output variables for which
277+ :param components: The names of the output variables for which
255278 to compute the advection. It must be a subset of the output labels.
256279 If ``None``, all output variables are considered. Default is ``None``.
257- :param list[str] d: The names of the input variables with respect to which
280+ :type components: str | list[str]
281+ :param d: The names of the input variables with respect to which
258282 the advection is computed. It must be a subset of the input labels.
259283 If ``None``, all input variables are considered. Default is ``None``.
284+ :type d: str | list[str]
260285 :return: The computed advection tensor.
261286 :rtype: LabelTensor
262287 """
@@ -266,6 +291,12 @@ def advection(output_, input_, velocity_field, components=None, d=None):
266291 if components is None :
267292 components = output_ .labels
268293
294+ if not isinstance (components , list ):
295+ components = [components ]
296+
297+ if not isinstance (d , list ):
298+ d = [d ]
299+
269300 tmp = (
270301 grad (output_ , input_ , components , d )
271302 .reshape (- 1 , len (components ), len (d ))
0 commit comments