
Commit 8f23a0b

Binary2355 and lizekai authored
【Feature】knn/tnn npu added (#3125)
Co-authored-by: lizekai <[email protected]>
1 parent: b91cfde

File tree: 3 files changed (+54 −2 lines)

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
  // transpose xyz from [B, N, 3] to [B, 3, N]
  at::Tensor source = xyz.transpose(1, 2).contiguous();
  at::Tensor target = new_xyz.contiguous();

  bool is_from_knn = true;
  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
                      const Tensor new_xyz, Tensor idx, Tensor dist2);

REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
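
Usage note (not part of the diff): a minimal sketch of how this NPU path can be exercised through the existing mmcv.ops.knn wrapper. It assumes a torch_npu environment in which tensors expose an .npu() method; the shapes and the k value are illustrative.

import torch
from mmcv.ops import knn

# Illustrative inputs: xyz holds (B, N, 3) reference points and center_xyz
# holds (B, npoint, 3) query points, both assumed to live on the NPU.
xyz = torch.rand(2, 256, 3).npu()
center_xyz = torch.rand(2, 64, 3).npu()

# The wrapper calls knn_forward_impl, which REGISTER_NPU_IMPL above binds to
# knn_forward_npu for NPU tensors; idx holds the indices of the 5 nearest
# reference points for each query point.
idx = knn(5, xyz, center_xyz)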
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
                          const Tensor known, Tensor dist2, Tensor idx) {
  // transpose known from [B, N, 3] to [B, 3, N]
  at::Tensor source = known.transpose(1, 2).contiguous();
  at::Tensor target = unknown.contiguous();
  auto originDtype = source.scalar_type();
  if (originDtype == at::kHalf) {
    source = source.to(at::kFloat);
    target = target.to(at::kFloat);
  }

  bool is_from_knn = false;
  uint32_t nsample = 3;
  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
  if (originDtype == at::kHalf) {
    dist2 = dist2.to(at::kHalf);
  }
}

void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
                           const Tensor known, Tensor dist2, Tensor idx);

REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
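
Usage note (not part of the diff): a similar sketch for the mmcv.ops.three_nn wrapper that this file backs on NPU devices, again assuming a torch_npu environment and illustrative tensor shapes.

import torch
from mmcv.ops import three_nn

# Illustrative inputs: unknown holds (B, n, 3) target points and known holds
# (B, m, 3) source points, both assumed to be NPU tensors.
unknown = torch.rand(2, 512, 3).npu()
known = torch.rand(2, 128, 3).npu()

# For every unknown point, returns the distances to and indices of its three
# nearest known points; half-precision inputs are upcast to float inside
# three_nn_forward_npu before EXEC_NPU_CMD is issued.
dist, idx = three_nn(unknown, known)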

mmcv/ops/knn.py

Lines changed: 3 additions & 2 deletions
@@ -55,8 +55,9 @@ def forward(ctx,
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
             'center_xyz and xyz should be put on the same device'
-        if torch.cuda.current_device() != center_xyz_device:
-            torch.cuda.set_device(center_xyz_device)
+        if xyz.device.type != 'npu':
+            if torch.cuda.current_device() != center_xyz_device:
+                torch.cuda.set_device(center_xyz_device)

         B, npoint, _ = center_xyz.shape
         N = xyz.shape[1]
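
The change above skips the CUDA current-device switch for NPU inputs, since torch.cuda.set_device does not apply to the 'npu' device type. A standalone restatement of that guard (illustrative only; _maybe_set_cuda_device is a hypothetical helper name, while the commit inlines the check in KNN.forward):

import torch

def _maybe_set_cuda_device(xyz: torch.Tensor, center_xyz_device: int) -> None:
    # Hypothetical helper mirroring the inlined check: only align the current
    # CUDA device when the tensors do not live on an NPU device.
    if xyz.device.type != 'npu':
        if torch.cuda.current_device() != center_xyz_device:
            torch.cuda.set_device(center_xyz_device)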
