Skip to content

Commit 02b9fc8

Browse files
committed
fix
1 parent 9440105 commit 02b9fc8

File tree

1 file changed

+46
-35
lines changed

1 file changed

+46
-35
lines changed

tests/test_ops/test_three_interpolate.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
7272
], [
7373
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
7474
]],
75-
[[
76-
8.1773e-01, 9.5440e-01, 2.4532e+00,
77-
8.1773e-01, 8.1773e-01, 1.1359e+00
78-
],
79-
[
80-
8.4689e-01, 1.9176e+00, 1.4715e+00,
81-
8.4689e-01, 8.4689e-01, 1.3079e+00
82-
],
83-
[
84-
6.9473e-01, 2.7440e-01, 2.0842e+00,
85-
6.9473e-01, 6.9473e-01, 7.8619e-01
86-
],
87-
[
88-
7.6789e-01, 1.5063e+00, 1.6209e+00,
89-
7.6789e-01, 7.6789e-01, 1.1562e+00
90-
],
91-
[
92-
3.8760e-01, 1.0300e-02, 8.3569e-09,
93-
3.8760e-01, 3.8760e-01, 1.9723e-01
94-
]]],
95-
dtype=dtype,
96-
device=device)
75+
[[
76+
8.1773e-01, 9.5440e-01, 2.4532e+00,
77+
8.1773e-01, 8.1773e-01, 1.1359e+00
78+
],
79+
[
80+
8.4689e-01, 1.9176e+00, 1.4715e+00,
81+
8.4689e-01, 8.4689e-01, 1.3079e+00
82+
],
83+
[
84+
6.9473e-01, 2.7440e-01, 2.0842e+00,
85+
6.9473e-01, 6.9473e-01, 7.8619e-01
86+
],
87+
[
88+
7.6789e-01, 1.5063e+00, 1.6209e+00,
89+
7.6789e-01, 7.6789e-01, 1.1562e+00
90+
],
91+
[
92+
3.8760e-01, 1.0300e-02, 8.3569e-09,
93+
3.8760e-01, 3.8760e-01, 1.9723e-01
94+
]]],
95+
dtype=dtype,
96+
device=device)
9797

9898
assert torch.allclose(output, expected_output, 1e-3, 1e-4)
9999

@@ -106,16 +106,17 @@ def three_interpolate_forward_gloden(features, idx, weight):
106106
if dtype == np.float16:
107107
features = features.astype(np.float32)
108108
weight = weight.astype(np.float32)
109-
110-
output = np.zeros((bs, cs, ns), dtype = np.float)
109+
110+
output = np.zeros((bs, cs, ns), dtype=np.float)
111111
for b in range(bs):
112112
for c in range(cs):
113113
for n in range(ns):
114114
output[b][c][n] = features[b][c][idx[b][n][0]] * weight[b][n][0] \
115-
+ features[b][c][idx[b][n][1]] * weight[b][n][1] \
116-
+ features[b][c][idx[b][n][2]] * weight[b][n][2]
115+
+ features[b][c][idx[b][n][1]] * weight[b][n][1] \
116+
+ features[b][c][idx[b][n][2]] * weight[b][n][2]
117117
return output
118118

119+
119120
def three_interpolate_backward_gloden(grad_output, idx, weight, features):
120121
bs, cs, ns = grad_output.shape
121122
ms = features.shape[2]
@@ -124,7 +125,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
124125
if dtype == np.float16:
125126
features = features.astype(np.float32)
126127
weight = weight.astype(np.float32)
127-
128+
128129
grad_point = np.zeros((bs, cs, ms), dtype=np.float)
129130
for b in range(bs):
130131
for c in range(cs):
@@ -137,6 +138,7 @@ def three_interpolate_backward_gloden(grad_output, idx, weight, features):
137138
grad_output[b][c][n] * weight[b][n][2]
138139
return grad_point
139140

141+
140142
def torch_type_trans(dtype):
141143
if dtype == torch.half:
142144
return np.float16
@@ -145,27 +147,36 @@ def torch_type_trans(dtype):
145147
else:
146148
return np.float64
147149

150+
148151
@pytest.mark.parametrize('dtype', [
149152
torch.half,
150153
torch.float
151154
])
152155
@pytest.mark.parametrize('device', [
153-
(2,5,6,6),
154-
(10,10,10,10),
155-
(20,21,13,4),
156-
(2,10,2,18),
157-
(10,602,910,200),
158-
(600,100,300,101)
156+
pytest.param(
157+
'npu',
158+
marks=pytest.mark.skipif(
159+
not IS_NPU_AVAILABLE, reason='requires NPU support'))
160+
])
161+
@pytest.mark.parametrize('shape', [
162+
(2, 5, 6, 6),
163+
(10, 10, 10, 10),
164+
(20, 21, 13, 4),
165+
(2, 10, 2, 18),
166+
(10, 602, 910, 200),
167+
(600, 100, 300, 101)
159168
])
160169
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
161170
bs = shape[0]
162171
cs = shape[1]
163172
ms = shape[2]
164173
ns = shape[3]
165174

166-
features = np.random.uniform(-10.0, 10.0, (bs, cs, ms).astype(torch_type_trans(dtype)))
175+
features = np.random.uniform(-10.0, 10.0,
176+
(bs, cs, ms).astype(torch_type_trans(dtype)))
167177
idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32)
168-
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)).astype(torch_type_trans(dtype))
178+
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)
179+
).astype(torch_type_trans(dtype))
169180

170181
features_npu = torch.tensor(features, dtype=dtype).to(device)
171182
idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)

0 commit comments

Comments
 (0)