Skip to content

Commit 5e36d1b

Browse files
Additional tests for operators (#384)
* Adding tests for operators
1 parent b896a2a commit 5e36d1b

File tree

1 file changed

+66
-19
lines changed

1 file changed

+66
-19
lines changed

tests/test_operators.py

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,73 +5,120 @@
55
from pina.operators import grad, div, laplacian
66

77

8-
def func_vec(x):
8+
def func_vector(x):
99
return x**2
1010

1111

1212
def func_scalar(x):
13-
print('X')
1413
x_ = x.extract(['x'])
1514
y_ = x.extract(['y'])
16-
mu_ = x.extract(['mu'])
17-
return x_**2 + y_**2 + mu_**3
15+
z_ = x.extract(['z'])
16+
return x_**2 + y_**2 + z_**2
1817

1918

20-
data = torch.rand((20, 3), requires_grad=True)
21-
inp = LabelTensor(data, ['x', 'y', 'mu'])
22-
labels = ['a', 'b', 'c']
23-
tensor_v = LabelTensor(func_vec(inp), labels)
24-
tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), labels[0])
25-
19+
inp = LabelTensor(torch.rand((20, 3), requires_grad=True), ['x', 'y', 'z'])
20+
tensor_v = LabelTensor(func_vector(inp), ['a', 'b', 'c'])
21+
tensor_s = LabelTensor(func_scalar(inp).reshape(-1, 1), ['a'])
2622

2723
def test_grad_scalar_output():
2824
grad_tensor_s = grad(tensor_s, inp)
25+
true_val = 2*inp
2926
assert grad_tensor_s.shape == inp.shape
3027
assert grad_tensor_s.labels == [
3128
f'd{tensor_s.labels[0]}d{i}' for i in inp.labels
3229
]
30+
assert torch.allclose(grad_tensor_s, true_val)
31+
3332
grad_tensor_s = grad(tensor_s, inp, d=['x', 'y'])
33+
true_val = 2*inp.extract(['x', 'y'])
3434
assert grad_tensor_s.shape == (inp.shape[0], 2)
3535
assert grad_tensor_s.labels == [
3636
f'd{tensor_s.labels[0]}d{i}' for i in ['x', 'y']
3737
]
38+
assert torch.allclose(grad_tensor_s, true_val)
3839

3940

4041
def test_grad_vector_output():
4142
grad_tensor_v = grad(tensor_v, inp)
43+
true_val = torch.cat(
44+
(2*inp.extract(['x']),
45+
torch.zeros_like(inp.extract(['y'])),
46+
torch.zeros_like(inp.extract(['z'])),
47+
torch.zeros_like(inp.extract(['x'])),
48+
2*inp.extract(['y']),
49+
torch.zeros_like(inp.extract(['z'])),
50+
torch.zeros_like(inp.extract(['x'])),
51+
torch.zeros_like(inp.extract(['y'])),
52+
2*inp.extract(['z'])
53+
), dim=1
54+
)
4255
assert grad_tensor_v.shape == (20, 9)
43-
grad_tensor_v = grad(tensor_v, inp, d=['x', 'mu'])
56+
assert grad_tensor_v.labels == [
57+
f'd{j}d{i}' for j in tensor_v.labels for i in inp.labels
58+
]
59+
assert torch.allclose(grad_tensor_v, true_val)
60+
61+
grad_tensor_v = grad(tensor_v, inp, d=['x', 'y'])
62+
true_val = torch.cat(
63+
(2*inp.extract(['x']),
64+
torch.zeros_like(inp.extract(['y'])),
65+
torch.zeros_like(inp.extract(['x'])),
66+
2*inp.extract(['y']),
67+
torch.zeros_like(inp.extract(['x'])),
68+
torch.zeros_like(inp.extract(['y']))
69+
), dim=1
70+
)
4471
assert grad_tensor_v.shape == (inp.shape[0], 6)
72+
assert grad_tensor_v.labels == [
73+
f'd{j}d{i}' for j in tensor_v.labels for i in ['x', 'y']
74+
]
75+
assert torch.allclose(grad_tensor_v, true_val)
4576

4677

4778
def test_div_vector_output():
48-
grad_tensor_v = div(tensor_v, inp)
49-
assert grad_tensor_v.shape == (20, 1)
50-
grad_tensor_v = div(tensor_v, inp, components=['a', 'b'], d=['x', 'mu'])
51-
assert grad_tensor_v.shape == (inp.shape[0], 1)
79+
div_tensor_v = div(tensor_v, inp)
80+
true_val = 2*torch.sum(inp, dim=1).reshape(-1,1)
81+
assert div_tensor_v.shape == (20, 1)
82+
assert div_tensor_v.labels == [f'dadx+dbdy+dcdz']
83+
assert torch.allclose(div_tensor_v, true_val)
84+
85+
div_tensor_v = div(tensor_v, inp, components=['a', 'b'], d=['x', 'y'])
86+
true_val = 2*torch.sum(inp.extract(['x', 'y']), dim=1).reshape(-1,1)
87+
assert div_tensor_v.shape == (inp.shape[0], 1)
88+
assert div_tensor_v.labels == [f'dadx+dbdy']
89+
assert torch.allclose(div_tensor_v, true_val)
5290

5391

5492
def test_laplacian_scalar_output():
55-
laplace_tensor_s = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
93+
laplace_tensor_s = laplacian(tensor_s, inp)
94+
true_val = 6*torch.ones_like(laplace_tensor_s)
5695
assert laplace_tensor_s.shape == tensor_s.shape
5796
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
97+
assert torch.allclose(laplace_tensor_s, true_val)
98+
99+
laplace_tensor_s = laplacian(tensor_s, inp, components=['a'], d=['x', 'y'])
58100
true_val = 4*torch.ones_like(laplace_tensor_s)
59-
assert all((laplace_tensor_s - true_val == 0).flatten())
101+
assert laplace_tensor_s.shape == tensor_s.shape
102+
assert laplace_tensor_s.labels == [f"dd{tensor_s.labels[0]}"]
103+
assert torch.allclose(laplace_tensor_s, true_val)
60104

61105

62106
def test_laplacian_vector_output():
63107
laplace_tensor_v = laplacian(tensor_v, inp)
108+
true_val = 2*torch.ones_like(tensor_v)
64109
assert laplace_tensor_v.shape == tensor_v.shape
65110
assert laplace_tensor_v.labels == [
66111
f'dd{i}' for i in tensor_v.labels
67112
]
113+
assert torch.allclose(laplace_tensor_v, true_val)
114+
68115
laplace_tensor_v = laplacian(tensor_v,
69116
inp,
70117
components=['a', 'b'],
71118
d=['x', 'y'])
119+
true_val = 2*torch.ones_like(tensor_v.extract(['a', 'b']))
72120
assert laplace_tensor_v.shape == tensor_v.extract(['a', 'b']).shape
73121
assert laplace_tensor_v.labels == [
74122
f'dd{i}' for i in ['a', 'b']
75123
]
76-
true_val = 2*torch.ones_like(tensor_v.extract(['a', 'b']))
77-
assert all((laplace_tensor_v - true_val == 0).flatten())
124+
assert torch.allclose(laplace_tensor_v, true_val)

0 commit comments

Comments
 (0)