
Commit a78f44e

Fixing Laplacian operator for vector fields (#380)
* fix laplacian and tests
1 parent db521ef commit a78f44e

File tree

2 files changed: +50 −26 lines


pina/operators.py

Lines changed: 37 additions & 24 deletions
@@ -170,53 +170,66 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
         is calculated. d should be a subset of the input labels. If None, all
         the input variables are considered. Default is None.
     :param str method: used method to calculate Laplacian, defaults to 'std'.
-    :raises ValueError: for vectorial field derivative with respect to
-        all coordinates must be performed.
+
     :raises NotImplementedError: 'divgrad' not implemented as method.
     :return: The tensor containing the result of the Laplacian operator.
     :rtype: LabelTensor
     """
+
+    def scalar_laplace(output_, input_, components, d):
+        """
+        Compute Laplace operator for a scalar output.
+
+        :param LabelTensor output_: the output tensor onto which computing the
+            Laplacian. It has to be a column tensor.
+        :param LabelTensor input_: the input tensor with respect to which
+            computing the Laplacian.
+        :param list(str) components: the name of the output variables to
+            calculate the Laplacian for. It should be a subset of the output
+            labels. If None, all the output variables are considered.
+        :param list(str) d: the name of the input variables on which the
+            Laplacian is computed. d should be a subset of the input labels.
+            If None, all the input variables are considered. Default is None.
+
+        :return: The tensor containing the result of the Laplacian operator.
+        :rtype: LabelTensor
+        """
+
+        grad_output = grad(output_, input_, components=components, d=d)
+        result = torch.zeros(output_.shape[0], 1, device=output_.device)
+
+        for i, label in enumerate(grad_output.labels):
+            gg = grad(grad_output, input_, d=d, components=[label])
+            result[:, 0] += super(torch.Tensor, gg.T).__getitem__(i)
+
+        return result
+
     if d is None:
         d = input_.labels
 
     if components is None:
         components = output_.labels
 
-    if len(components) != len(d) and len(components) != 1:
-        raise ValueError
-
     if method == "divgrad":
         raise NotImplementedError("divgrad not implemented as method")
         # TODO fix
         # grad_output = grad(output_, input_, components, d)
         # result = div(grad_output, input_, d=d)
-    elif method == "std":
 
+    elif method == "std":
         if len(components) == 1:
-            grad_output = grad(output_, input_, components=components, d=d)
-            result = torch.zeros(output_.shape[0], 1, device=output_.device)
-            for i, label in enumerate(grad_output.labels):
-                gg = grad(grad_output, input_, d=d, components=[label])
-                result[:, 0] += super(torch.Tensor, gg.T).__getitem__(
-                    i
-                )  # TODO improve
+            result = scalar_laplace(output_, input_, components, d)
            labels = [f"dd{components[0]}"]
 
         else:
             result = torch.empty(
-                input_.shape[0], len(components), device=output_.device
+                size=(input_.shape[0], len(components)),
+                dtype=output_.dtype, device=output_.device
             )
             labels = [None] * len(components)
-            for idx, (ci, di) in enumerate(zip(components, d)):
-
-                if not isinstance(ci, list):
-                    ci = [ci]
-                if not isinstance(di, list):
-                    di = [di]
-
-                grad_output = grad(output_, input_, components=ci, d=di)
-                result[:, idx] = grad(grad_output, input_, d=di).flatten()
-                labels[idx] = f"dd{ci}dd{di}"
+            for idx, c in enumerate(components):
+                result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
+                labels[idx] = f"dd{c}"
 
         result = result.as_subclass(LabelTensor)
         result.labels = labels
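
The new nested helper scalar_laplace computes the Laplacian of a single output component as the trace of the Hessian: it takes the gradient with respect to the requested input variables and then sums the second derivative along each of them, and the vector branch now simply calls it once per component. A minimal, framework-free sketch of that same computation in plain PyTorch (the names pts, u and laplace_u are illustrative, not part of PINA):

import torch

# Two coordinates play the role of the input labels 'x' and 'y'.
pts = torch.rand(10, 2, requires_grad=True)
u = (pts ** 2).sum(dim=1, keepdim=True)   # u = x^2 + y^2, so its Laplacian is 4

# First derivatives, kept in the graph so they can be differentiated again.
grad_u = torch.autograd.grad(u.sum(), pts, create_graph=True)[0]

# Trace of the Hessian: accumulate d^2u/dx_i^2 one coordinate at a time,
# just as scalar_laplace does label by label.
laplace_u = torch.zeros(pts.shape[0], 1)
for i in range(pts.shape[1]):
    gg = torch.autograd.grad(grad_u[:, i].sum(), pts, retain_graph=True)[0]
    laplace_u[:, 0] += gg[:, i]

print(torch.allclose(laplace_u, 4 * torch.ones_like(laplace_u)))  # True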

tests/test_operators.py

Lines changed: 13 additions & 2 deletions
@@ -52,15 +52,26 @@ def test_div_vector_output():
 
 
 def test_laplacian_scalar_output():
-    laplace_tensor_v = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
-    assert laplace_tensor_v.shape == tensor_s.shape
+    laplace_tensor_s = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
+    assert laplace_tensor_s.shape == tensor_s.shape
+    assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
+    true_val = 4*torch.ones_like(laplace_tensor_s)
+    assert all((laplace_tensor_s - true_val == 0).flatten())
 
 
 def test_laplacian_vector_output():
     laplace_tensor_v = laplacian(tensor_v, inp)
     assert laplace_tensor_v.shape == tensor_v.shape
+    assert laplace_tensor_v.labels == [
+        f'dd{i}' for i in tensor_v.labels
+    ]
     laplace_tensor_v = laplacian(tensor_v,
                                  inp,
                                  components=['a', 'b'],
                                  d=['x', 'y'])
     assert laplace_tensor_v.shape == tensor_v.extract(['a', 'b']).shape
+    assert laplace_tensor_v.labels == [
+        f'dd{i}' for i in ['a', 'b']
+    ]
+    true_val = 2*torch.ones_like(tensor_v.extract(['a', 'b']))
+    assert all((laplace_tensor_v - true_val == 0).flatten())
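
The fixtures tensor_s, tensor_v and inp are defined earlier in tests/test_operators.py and are not shown in this hunk. The sketch below is an assumed reconstruction (including the LabelTensor constructor and import paths), chosen only so that the asserted values come out: a scalar field x^2 + y^2 with Laplacian 4, and a vector field (x^2, y^2) whose components both have Laplacian 2. It is not the repository's actual fixture code:

import torch
from pina import LabelTensor
from pina.operators import laplacian

# Assumed fixtures: input points carrying the labels 'x' and 'y'.
data = torch.rand((20, 2), requires_grad=True)
inp = LabelTensor(data, ['x', 'y'])

x, y = inp.extract(['x']), inp.extract(['y'])
tensor_s = LabelTensor(x**2 + y**2, ['a'])                          # Laplacian = 2 + 2 = 4
tensor_v = LabelTensor(torch.cat((x**2, y**2), dim=1), ['a', 'b'])  # each component has Laplacian 2

lap_v = laplacian(tensor_v, inp, components=['a', 'b'], d=['x', 'y'])
print(lap_v.labels)   # ['dda', 'ddb'] with this fix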
