@@ -170,53 +170,66 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
         is calculated. d should be a subset of the input labels. If None, all
         the input variables are considered. Default is None.
     :param str method: the method used to calculate the Laplacian, defaults to 'std'.
-    :raises ValueError: for vectorial field derivative with respect to
-        all coordinates must be performed.
+
     :raises NotImplementedError: 'divgrad' not implemented as method.
     :return: The tensor containing the result of the Laplacian operator.
     :rtype: LabelTensor
     """
+
+    def scalar_laplace(output_, input_, components, d):
+        """
+        Compute the Laplace operator for a scalar output.
+
+        :param LabelTensor output_: the output tensor on which the Laplacian
+            is computed. It has to be a column tensor.
+        :param LabelTensor input_: the input tensor with respect to which the
+            Laplacian is computed.
+        :param list(str) components: the name of the output variables to
+            calculate the Laplacian for. It should be a subset of the output
+            labels. If None, all the output variables are considered.
+        :param list(str) d: the name of the input variables on which the
+            Laplacian is computed. d should be a subset of the input labels.
+            If None, all the input variables are considered. Default is None.
+
+        :return: The tensor containing the result of the Laplacian operator.
+        :rtype: torch.Tensor
+        """
+
+        grad_output = grad(output_, input_, components=components, d=d)
+        result = torch.zeros(output_.shape[0], 1, device=output_.device)
+
+        for i, label in enumerate(grad_output.labels):
+            gg = grad(grad_output, input_, d=d, components=[label])
+            result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)  # positional access, bypassing LabelTensor's label-based __getitem__
+
+        return result
+
     if d is None:
         d = input_.labels

     if components is None:
         components = output_.labels

-    if len(components) != len(d) and len(components) != 1:
-        raise ValueError
-
     if method == "divgrad":
         raise NotImplementedError("divgrad not implemented as method")
         # TODO fix
         # grad_output = grad(output_, input_, components, d)
         # result = div(grad_output, input_, d=d)
-    elif method == "std":

+    elif method == "std":
         if len(components) == 1:
-            grad_output = grad(output_, input_, components=components, d=d)
-            result = torch.zeros(output_.shape[0], 1, device=output_.device)
-            for i, label in enumerate(grad_output.labels):
-                gg = grad(grad_output, input_, d=d, components=[label])
-                result[:, 0] += super(torch.Tensor, gg.T).__getitem__(
-                    i
-                )  # TODO improve
+            result = scalar_laplace(output_, input_, components, d)
             labels = [f"dd{components[0]}"]

         else:
             result = torch.empty(
-                input_.shape[0], len(components), device=output_.device
+                size=(input_.shape[0], len(components)),
+                dtype=output_.dtype, device=output_.device
             )
             labels = [None] * len(components)
-            for idx, (ci, di) in enumerate(zip(components, d)):
-
-                if not isinstance(ci, list):
-                    ci = [ci]
-                if not isinstance(di, list):
-                    di = [di]
-
-                grad_output = grad(output_, input_, components=ci, d=di)
-                result[:, idx] = grad(grad_output, input_, d=di).flatten()
-                labels[idx] = f"dd{ci}dd{di}"
+            for idx, c in enumerate(components):
+                result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
+                labels[idx] = f"dd{c}"

     result = result.as_subclass(LabelTensor)
     result.labels = labels
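
As a quick sanity check of the refactor, here is a minimal usage sketch. It is an assumption-laden illustration, not part of the commit: the import paths follow PINA's layout, the labels `x`, `y`, `u`, `v` are made up, and the test function u = x^2 + y^2 has the constant Laplacian 4.

```python
import torch

from pina import LabelTensor
from pina.operators import laplacian

# ten random 2D points, with gradients tracked through the inputs
pts = LabelTensor(torch.rand(10, 2, requires_grad=True), labels=["x", "y"])

# u = x^2 + y^2, so its Laplacian is 2 + 2 = 4 at every point
u = LabelTensor(pts.extract(["x"]) ** 2 + pts.extract(["y"]) ** 2, labels=["u"])

# scalar path (len(components) == 1): one column labelled "ddu"
lap_u = laplacian(u, pts)

# vector path: one column per component, labelled "ddu" and "ddv",
# each computed over the same coordinate set d = ["x", "y"]
uv = LabelTensor(torch.cat((u, u), dim=1), labels=["u", "v"])
lap_uv = laplacian(uv, pts)
```

Note the behavioural change the sketch relies on: the old code zipped components with coordinates (hence the deleted ValueError when their lengths differed), while the refactored vector path computes every component's Laplacian over the full coordinate set `d`.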
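The TODO kept in the 'divgrad' branch shows the intended implementation: composing the existing `grad` and `div` operators. Below is a minimal sketch of that composition, written as a free-standing helper and assuming `div` accepts the gradient's output labels as-is; it mirrors the commented-out lines and is not committed behaviour.

```python
from pina.operators import div, grad

def divgrad_laplacian(output_, input_, components=None, d=None):
    # laplacian(u) = div(grad(u)); mirrors the commented-out TODO lines
    grad_output = grad(output_, input_, components=components, d=d)
    return div(grad_output, input_, d=d)
```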