
Commit c0774b5

[Feature] Add the support of BallQuery op for Ascend device (#2963)
1 parent 57c4e25 commit c0774b5

File tree

2 files changed: +45 -2 lines changed
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+void ball_query_forward_npu(int b, int n, int m, float min_radius,
+                            float max_radius, int nsample, const Tensor new_xyz,
+                            const Tensor xyz, Tensor idx) {
+  int64_t nsample_i64 = nsample;
+
+  // transpose new_xyz from [B, M, 3] to [M, B, 3]
+  at::Tensor new_xyz_transpose = new_xyz.transpose(0, 1);
+
+  // transpose xyz from [B, N, 3] to [B, 3, N]
+  at::Tensor xyz_transpose = xyz.transpose(1, 2);
+
+  // transpose idx from [B, M, nsample] to [M, B, nsample]
+  at::Tensor idx_transpose = NpuUtils::format_contiguous(idx.transpose(0, 1));
+
+  OpCommand cmd;
+  cmd.Name("BallQuery")
+      .Input(xyz_transpose)
+      .Input(new_xyz_transpose)
+      .Output(idx_transpose)
+      .Attr("min_radius", min_radius)
+      .Attr("max_radius", max_radius)
+      .Attr("sample_num", nsample_i64)
+      .Run();
+
+  idx_transpose = NpuUtils::format_contiguous(idx_transpose.transpose(0, 1));
+  idx.copy_(idx_transpose);
+}
+
+void ball_query_forward_impl(int b, int n, int m, float min_radius,
+                             float max_radius, int nsample,
+                             const Tensor new_xyz, const Tensor xyz,
+                             Tensor idx);
+
+REGISTER_NPU_IMPL(ball_query_forward_impl, ball_query_forward_npu);
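
For context, the new file routes ball_query to the CANN "BallQuery" operator: the inputs are transposed into the layout that operator expects, the operator is launched through OpCommand, and the result is transposed back into idx. A minimal usage sketch from Python is shown below; it assumes torch_npu is installed so tensors can be moved to the 'npu' device and that mmcv.ops.ball_query keeps its usual (min_radius, max_radius, sample_num, xyz, center_xyz) argument order. The shapes and radii are illustrative only, not taken from this commit.

# Hedged usage sketch, not part of the commit.
import torch
from mmcv.ops import ball_query

xyz = torch.rand(2, 256, 3).to('npu')      # [B, N, 3] points
new_xyz = torch.rand(2, 64, 3).to('npu')   # [B, M, 3] query centers
# idx has shape [B, M, nsample] and holds indices into xyz
idx = ball_query(0.0, 0.2, 16, xyz, new_xyz)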

tests/test_ops/test_ball_query.py

Lines changed: 6 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch
 
 from mmcv.ops import ball_query
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
 
 
 @pytest.mark.parametrize('device', [
@@ -14,7 +14,11 @@
     pytest.param(
         'mlu',
         marks=pytest.mark.skipif(
-            not IS_MLU_AVAILABLE, reason='requires MLU support'))
+            not IS_MLU_AVAILABLE, reason='requires MLU support')),
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
 ])
 def test_ball_query(device):
     new_xyz = torch.tensor(
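
The test change only adds an 'npu' case to the existing device parametrization, so the NPU path runs through the same reference checks as CUDA and MLU. As a rough illustration of what such a device-gated check looks like in isolation, a minimal shape/range sanity test is sketched below; the test name and the concrete shapes are assumptions for illustration, not part of the test file.

# Minimal sanity sketch, assuming torch_npu registers the 'npu' device and
# that ball_query takes (min_radius, max_radius, sample_num, xyz, center_xyz).
import pytest
import torch

from mmcv.ops import ball_query
from mmcv.utils import IS_NPU_AVAILABLE


@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
def test_ball_query_npu_shape():
    xyz = torch.rand(2, 128, 3).to('npu')     # [B, N, 3]
    new_xyz = torch.rand(2, 32, 3).to('npu')  # [B, M, 3]
    idx = ball_query(0.0, 0.2, 8, xyz, new_xyz)
    assert idx.shape == (2, 32, 8)             # [B, M, nsample]
    assert idx.min() >= 0 and idx.max() < 128  # indices point into xyz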
