@@ -75,10 +75,8 @@ def _scalar_grad(output_, input_, d):
7575 components = components or output_ .labels
7676
7777 # Convert to list if not already
78- d = [d ] if not isinstance (d , list ) else d
79- components = (
80- [components ] if not isinstance (components , list ) else components
81- )
78+ d = d if isinstance (d , list ) else [d ]
79+ components = components if isinstance (components , list ) else [components ]
8280
8381 # Check if all labels are present in the input tensor
8482 if not all (di in input_ .labels for di in d ):
@@ -142,10 +140,8 @@ def div(output_, input_, components=None, d=None):
142140 components = components or output_ .labels
143141
144142 # Convert to list if not already
145- d = [d ] if not isinstance (d , list ) else d
146- components = (
147- [components ] if not isinstance (components , list ) else components
148- )
143+ d = d if isinstance (d , list ) else [d ]
144+ components = components if isinstance (components , list ) else [components ]
149145
150146 # Components and d must be of the same length
151147 if len (components ) != len (d ):
@@ -205,13 +201,8 @@ def _scalar_laplacian(output_, input_, d):
205201 :rtype: LabelTensor
206202 """
207203 first_grad = grad (output_ = output_ , input_ = input_ , d = d )
208- result = torch .zeros (output_ .shape [0 ], 1 , device = output_ .device )
209-
210- second_grad = grad (output_ = first_grad , input_ = input_ , d = d ).T
211- for i in range (second_grad .shape [0 ]):
212- result [:, 0 ] += second_grad [i ]
213-
214- return result
204+ second_grad = grad (output_ = first_grad , input_ = input_ , d = d )
205+ return torch .sum (second_grad , dim = 1 , keepdim = True )
215206
216207 # Check if the input is a LabelTensor
217208 if not isinstance (input_ , LabelTensor ):
@@ -222,10 +213,8 @@ def _scalar_laplacian(output_, input_, d):
222213 components = components or output_ .labels
223214
224215 # Convert to list if not already
225- d = [d ] if not isinstance (d , list ) else d
226- components = (
227- [components ] if not isinstance (components , list ) else components
228- )
216+ d = d if isinstance (d , list ) else [d ]
217+ components = components if isinstance (components , list ) else [components ]
229218
230219 # Scalar laplacian
231220 if output_ .shape [1 ] == 1 :
@@ -242,20 +231,30 @@ def _scalar_laplacian(output_, input_, d):
242231
243232 # Vector laplacian
244233 if method == "std" :
245- for idx , c in enumerate (components ):
246- result [:, idx ] = _scalar_laplacian (
247- output_ = output_ .extract (c ), input_ = input_ , d = d
248- ).flatten ()
234+ result = torch .stack (
235+ [
236+ _scalar_laplacian (
237+ output_ = output_ .extract (c ), input_ = input_ , d = d
238+ ).flatten ()
239+ for c in components
240+ ],
241+ dim = 1 ,
242+ )
249243
250244 elif method == "divgrad" :
251245 grads = grad (output_ = output_ , input_ = input_ , components = components , d = d )
252- for idx , c in enumerate (components ):
253- result [:, idx ] = div (
254- output_ = grads ,
255- input_ = input_ ,
256- components = [f"d{ c } d{ i } " for i in d ],
257- d = d ,
258- ).flatten ()
246+ result = torch .stack (
247+ [
248+ div (
249+ output_ = grads ,
250+ input_ = input_ ,
251+ components = [f"d{ c } d{ i } " for i in d ],
252+ d = d ,
253+ ).flatten ()
254+ for c in components
255+ ],
256+ dim = 1 ,
257+ )
259258
260259 else :
261260 raise ValueError (
@@ -299,10 +298,8 @@ def advection(output_, input_, velocity_field, components=None, d=None):
299298 components = components or output_ .labels
300299
301300 # Convert to list if not already
302- d = [d ] if not isinstance (d , list ) else d
303- components = (
304- [components ] if not isinstance (components , list ) else components
305- )
301+ d = d if isinstance (d , list ) else [d ]
302+ components = components if isinstance (components , list ) else [components ]
306303
307304 # Check if velocity field is present in the output labels
308305 if velocity_field not in output_ .labels :
0 commit comments