Skip to content

Commit bf8dc39

Browse files
fix bug in laplace labels
1 parent ef29f0a commit bf8dc39

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

pina/operator.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,21 @@ def scalar_laplace(output_, input_, components, d):
216216
if components is None:
217217
components = output_.labels
218218

219+
if not isinstance(components, list):
220+
components = [components]
221+
219222
if method == "divgrad":
220223
raise NotImplementedError("divgrad not implemented as method")
221224

222225
if method == "std":
223-
if len(components) == 1:
224-
result = scalar_laplace(output_, input_, components, d)
225-
labels = [f"dd{components[0]}"]
226-
227-
else:
228-
result = torch.empty(
229-
input_.shape[0], len(components), device=output_.device
230-
)
231-
labels = [None] * len(components)
232-
for idx, c in enumerate(components):
233-
result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
234-
labels[idx] = f"dd{c}"
226+
227+
result = torch.empty(
228+
input_.shape[0], len(components), device=output_.device
229+
)
230+
labels = [None] * len(components)
231+
for idx, c in enumerate(components):
232+
result[:, idx] = scalar_laplace(output_, input_, [c], d).flatten()
233+
labels[idx] = f"dd{c}"
235234

236235
result = result.as_subclass(LabelTensor)
237236
result.labels = labels

0 commit comments

Comments
 (0)