
Commit 348cfa7

additional improvements
1 parent 5b1b5cb


pina/operator.py

Lines changed: 31 additions & 34 deletions
@@ -75,10 +75,8 @@ def _scalar_grad(output_, input_, d):
     components = components or output_.labels
 
     # Convert to list if not already
-    d = [d] if not isinstance(d, list) else d
-    components = (
-        [components] if not isinstance(components, list) else components
-    )
+    d = d if isinstance(d, list) else [d]
+    components = components if isinstance(components, list) else [components]
 
     # Check if all labels are present in the input tensor
     if not all(di in input_.labels for di in d):
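
Note: the same list-coercion change recurs in the div, laplacian, and advection hunks below. A minimal standalone check of the two forms (illustrative only, not part of the commit):

    # Both forms pass lists through unchanged and wrap a bare scalar in a
    # list; the new one reads in source order: "keep it if it is already
    # a list, otherwise wrap it".
    for arg in ("x", ["x", "y"]):
        old = [arg] if not isinstance(arg, list) else arg
        new = arg if isinstance(arg, list) else [arg]
        assert old == new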
@@ -142,10 +140,8 @@ def div(output_, input_, components=None, d=None):
     components = components or output_.labels
 
     # Convert to list if not already
-    d = [d] if not isinstance(d, list) else d
-    components = (
-        [components] if not isinstance(components, list) else components
-    )
+    d = d if isinstance(d, list) else [d]
+    components = components if isinstance(components, list) else [components]
 
     # Components and d must be of the same length
     if len(components) != len(d):
@@ -205,13 +201,8 @@ def _scalar_laplacian(output_, input_, d):
         :rtype: LabelTensor
         """
         first_grad = grad(output_=output_, input_=input_, d=d)
-        result = torch.zeros(output_.shape[0], 1, device=output_.device)
-
-        second_grad = grad(output_=first_grad, input_=input_, d=d).T
-        for i in range(second_grad.shape[0]):
-            result[:, 0] += second_grad[i]
-
-        return result
+        second_grad = grad(output_=first_grad, input_=input_, d=d)
+        return torch.sum(second_grad, dim=1, keepdim=True)
 
     # Check if the input is a LabelTensor
     if not isinstance(input_, LabelTensor):
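
The scalar laplacian now reduces the columns of the second gradient with a single torch.sum instead of transposing and accumulating row by row into a preallocated buffer. A small sketch of the equivalence (second_grad below is a random stand-in for grad(first_grad, input_, d=d), one column per derivative direction):

    import torch

    second_grad = torch.randn(5, 3)  # (n_points, n_directions)

    # Old formulation: transpose, then accumulate one direction at a time.
    old = torch.zeros(second_grad.shape[0], 1)
    for row in second_grad.T:
        old[:, 0] += row

    # New formulation: one reduction over the direction dimension;
    # keepdim=True preserves the (n_points, 1) column shape.
    new = torch.sum(second_grad, dim=1, keepdim=True)

    assert torch.allclose(old, new)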
@@ -222,10 +213,8 @@ def _scalar_laplacian(output_, input_, d):
     components = components or output_.labels
 
     # Convert to list if not already
-    d = [d] if not isinstance(d, list) else d
-    components = (
-        [components] if not isinstance(components, list) else components
-    )
+    d = d if isinstance(d, list) else [d]
+    components = components if isinstance(components, list) else [components]
 
     # Scalar laplacian
     if output_.shape[1] == 1:
@@ -242,20 +231,30 @@ def _scalar_laplacian(output_, input_, d):
 
     # Vector laplacian
     if method == "std":
-        for idx, c in enumerate(components):
-            result[:, idx] = _scalar_laplacian(
-                output_=output_.extract(c), input_=input_, d=d
-            ).flatten()
+        result = torch.stack(
+            [
+                _scalar_laplacian(
+                    output_=output_.extract(c), input_=input_, d=d
+                ).flatten()
+                for c in components
+            ],
+            dim=1,
+        )
 
     elif method == "divgrad":
         grads = grad(output_=output_, input_=input_, components=components, d=d)
-        for idx, c in enumerate(components):
-            result[:, idx] = div(
-                output_=grads,
-                input_=input_,
-                components=[f"d{c}d{i}" for i in d],
-                d=d,
-            ).flatten()
+        result = torch.stack(
+            [
+                div(
+                    output_=grads,
+                    input_=input_,
+                    components=[f"d{c}d{i}" for i in d],
+                    d=d,
+                ).flatten()
+                for c in components
+            ],
+            dim=1,
+        )
 
     else:
         raise ValueError(
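
Both vector-laplacian branches now assemble the result with torch.stack over a list comprehension rather than writing columns into a preallocated tensor. A standalone sketch of the equivalence (the per-component tensors are random stand-ins for the _scalar_laplacian / div outputs):

    import torch

    n_points, components = 5, ["u", "v", "w"]
    per_component = {c: torch.randn(n_points) for c in components}

    # Old formulation: preallocate, then fill column by column in place.
    old = torch.zeros(n_points, len(components))
    for idx, c in enumerate(components):
        old[:, idx] = per_component[c]

    # New formulation: stack the flattened (n_points,) results along dim=1
    # to get the same (n_points, n_components) layout.
    new = torch.stack([per_component[c] for c in components], dim=1)

    assert torch.equal(old, new)

Building the result in one stack call also sidesteps in-place index assignment into a preallocated buffer, which tends to be friendlier to autograd.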
@@ -299,10 +298,8 @@ def advection(output_, input_, velocity_field, components=None, d=None):
     components = components or output_.labels
 
     # Convert to list if not already
-    d = [d] if not isinstance(d, list) else d
-    components = (
-        [components] if not isinstance(components, list) else components
-    )
+    d = d if isinstance(d, list) else [d]
+    components = components if isinstance(components, list) else [components]
 
     # Check if velocity field is present in the output labels
     if velocity_field not in output_.labels:
