Skip to content

Commit 94dff26

Browse files
authored
[Feature] Add the support of three_interpolate op for Ascend device (#2962)
1 parent c0774b5 commit 94dff26

File tree

2 files changed

+59
-11
lines changed

2 files changed

+59
-11
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "pytorch_npu_helper.hpp"
2+
3+
using namespace NPU_NAME_SPACE;
4+
using namespace std;
5+
6+
void three_interpolate_forward_npu(int b, int c, int m, int n,
7+
const Tensor points, const Tensor idx,
8+
const Tensor weight, Tensor out) {
9+
auto point_c_trans = points.transpose(1, 2);
10+
11+
OpCommand cmd;
12+
cmd.Name("ThreeInterpolate")
13+
.Input(point_c_trans)
14+
.Input(idx)
15+
.Input(weight)
16+
.Output(out)
17+
.Run();
18+
19+
auto output = out.view({b, n, c}).transpose(1, 2);
20+
auto res = NpuUtils::format_contiguous(output);
21+
out.copy_(res);
22+
}
23+
24+
// Forward declaration of the device-dispatch stub; its definition lives
// elsewhere in the project (NOTE(review): presumably the common
// three_interpolate op file — confirm).
void three_interpolate_forward_impl(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out);

// Bind the NPU kernel above as the backend implementation selected when
// three_interpolate_forward_impl is dispatched on an Ascend device.
REGISTER_NPU_IMPL(three_interpolate_forward_impl,
                  three_interpolate_forward_npu);

tests/test_ops/test_three_interpolate.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,28 @@
33
import torch
44

55
from mmcv.ops import three_interpolate
6+
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
67

78

8-
@pytest.mark.skipif(
9-
not torch.cuda.is_available(), reason='requires CUDA support')
10-
@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
11-
def test_three_interpolate(dtype):
9+
@pytest.mark.parametrize('dtype', [
10+
torch.half, torch.float,
11+
pytest.param(
12+
torch.double,
13+
marks=pytest.mark.skipif(
14+
IS_NPU_AVAILABLE,
15+
reason='NPU does not support for 64-bit floating point'))
16+
])
17+
@pytest.mark.parametrize('device', [
18+
pytest.param(
19+
'cuda',
20+
marks=pytest.mark.skipif(
21+
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
22+
pytest.param(
23+
'npu',
24+
marks=pytest.mark.skipif(
25+
not IS_NPU_AVAILABLE, reason='requires NPU support'))
26+
])
27+
def test_three_interpolate(dtype, device):
1228
features = torch.tensor(
1329
[[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
1430
[3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
@@ -20,12 +36,13 @@ def test_three_interpolate(dtype):
2036
[0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
2137
[0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
2238
[0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]],
23-
dtype=dtype).cuda()
39+
dtype=dtype,
40+
device=device)
2441

25-
idx = torch.tensor([[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2],
26-
[0, 1, 3]],
27-
[[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4],
28-
[0, 1, 2]]]).int().cuda()
42+
idx = torch.tensor(
43+
[[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2], [0, 1, 3]],
44+
[[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4], [0, 1, 2]]],
45+
device=device).int()
2946

3047
weight = torch.tensor([[[3.3333e-01, 3.3333e-01, 3.3333e-01],
3148
[1.0000e+00, 5.8155e-08, 2.2373e-08],
@@ -39,7 +56,8 @@ def test_three_interpolate(dtype):
3956
[3.3333e-01, 3.3333e-01, 3.3333e-01],
4057
[3.3333e-01, 3.3333e-01, 3.3333e-01],
4158
[3.3333e-01, 3.3333e-01, 3.3333e-01]]],
42-
dtype=dtype).cuda()
59+
dtype=dtype,
60+
device=device)
4361

4462
output = three_interpolate(features, idx, weight)
4563
expected_output = torch.tensor([[[
@@ -73,6 +91,7 @@ def test_three_interpolate(dtype):
7391
3.8760e-01, 1.0300e-02, 8.3569e-09,
7492
3.8760e-01, 3.8760e-01, 1.9723e-01
7593
]]],
76-
dtype=dtype).cuda()
94+
dtype=dtype,
95+
device=device)
7796

7897
assert torch.allclose(output, expected_output, 1e-3, 1e-4)

0 commit comments

Comments
 (0)