diff --git a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
index 0f1b14e7dc..6832dc51f6 100644
--- a/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
@@ -1,4 +1,6 @@
 #include "pytorch_npu_helper.hpp"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"

 using namespace NPU_NAME_SPACE;
 using namespace std;
@@ -6,6 +8,10 @@ using namespace std;
 void three_interpolate_forward_npu(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out) {
+  auto originDtype = points.scalar_type();
+  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
+              "three_interpolate_forward ascend only supports fp32 and fp16.");
+
   auto point_c_trans = points.transpose(1, 2);

   OpCommand cmd;
@@ -21,9 +27,33 @@ void three_interpolate_forward_npu(int b, int c, int m, int n,
   out.copy_(res);
 }

+void three_interpolate_backward_npu(int b, int c, int n, int m,
+                                    const Tensor grad_out, const Tensor idx,
+                                    const Tensor weight, Tensor grad_points) {
+  auto originDtype = grad_out.scalar_type();
+  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
+              "three_interpolate_backward ascend only supports fp32 and fp16.");
+
+  auto grad_x = at::unsqueeze(grad_out, 3);
+  auto grad_y = at::unsqueeze(grad_points, 3);
+
+  EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
+
+  auto output = at::squeeze(grad_y, 3);
+  auto res = NpuUtils::format_contiguous(output);
+  grad_points.copy_(res);
+}
+
 void three_interpolate_forward_impl(int b, int c, int m, int n,
                                     const Tensor points, const Tensor idx,
                                     const Tensor weight, Tensor out);

+void three_interpolate_backward_impl(int b, int c, int n, int m,
+                                     const Tensor grad_out, const Tensor idx,
+                                     const Tensor weight, Tensor grad_points);
+
 REGISTER_NPU_IMPL(three_interpolate_forward_impl,
                   three_interpolate_forward_npu);
+
+REGISTER_NPU_IMPL(three_interpolate_backward_impl,
+                  three_interpolate_backward_npu);
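The backward kernel unsqueezes the gradient tensors to 4-D for `aclnnThreeInterpolateBackward` and squeezes the result back, mirroring the layout handling in the forward path. As a quick smoke check of the registered op pair, a minimal sketch (assumes an Ascend environment with `torch_npu` and an NPU build of mmcv; shapes are illustrative and not taken from the patch):

```python
import torch
import torch_npu  # noqa: F401  # registers the 'npu' device (Ascend only)

from mmcv.ops import three_interpolate

# B=2 batches, C=4 channels, M=8 known points, N=16 interpolated points.
features = torch.rand(2, 4, 8, device='npu', requires_grad=True)
idx = torch.randint(0, 8, (2, 16, 3), device='npu', dtype=torch.int32)
weight = torch.rand(2, 16, 3, device='npu')

out = three_interpolate(features, idx, weight)  # shape (2, 4, 16)
out.sum().backward()  # dispatches to three_interpolate_backward_npu
assert features.grad.shape == features.shape  # gradients land on (B, C, M)
```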
diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py
index d27a795ecf..3de6ddd769 100644
--- a/tests/test_ops/test_three_interpolate.py
+++ b/tests/test_ops/test_three_interpolate.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
 import pytest
 import torch
@@ -95,3 +96,81 @@ def test_three_interpolate(dtype, device):
                             device=device)

     assert torch.allclose(output, expected_output, 1e-3, 1e-4)
+
+
+def three_interpolate_forward_golden(features, idx, weight):
+    """Reference forward: weighted sum of the three indexed neighbors."""
+    bs, cs, ms = features.shape
+    ns = idx.shape[1]
+
+    dtype = features.dtype
+    output = np.zeros((bs, cs, ns), dtype=dtype)
+    for b in range(bs):
+        for c in range(cs):
+            for n in range(ns):
+                output[b][c][n] = \
+                    features[b][c][idx[b][n][0]] * weight[b][n][0] \
+                    + features[b][c][idx[b][n][1]] * weight[b][n][1] \
+                    + features[b][c][idx[b][n][2]] * weight[b][n][2]
+    return output
+
+
+def three_interpolate_backward_golden(grad_output, idx, weight, features):
+    """Reference backward: scatter-add the weighted output gradients."""
+    bs, cs, ns = grad_output.shape
+    ms = features.shape[2]
+
+    dtype = features.dtype
+    grad_point = np.zeros((bs, cs, ms), dtype=dtype)
+    for b in range(bs):
+        for c in range(cs):
+            for n in range(ns):
+                grad_point[b][c][idx[b][n][0]] = \
+                    grad_point[b][c][idx[b][n][0]] + \
+                    grad_output[b][c][n] * weight[b][n][0]
+                grad_point[b][c][idx[b][n][1]] = \
+                    grad_point[b][c][idx[b][n][1]] + \
+                    grad_output[b][c][n] * weight[b][n][1]
+                grad_point[b][c][idx[b][n][2]] = \
+                    grad_point[b][c][idx[b][n][2]] + \
+                    grad_output[b][c][n] * weight[b][n][2]
+    return grad_point
+
+
+def torch_type_trans(dtype):
+    if dtype == torch.half:
+        return np.float16
+    elif dtype == torch.float:
+        return np.float32
+    else:
+        return np.float64
+
+
+@pytest.mark.parametrize('dtype', [torch.half, torch.float])
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
+                                   (20, 21, 13, 4), (2, 10, 2, 18),
+                                   (10, 602, 910, 200), (600, 100, 300, 101)])
+def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
+    bs, cs, ms, ns = shape
+
+    features = np.random.uniform(-10.0, 10.0,
+                                 (bs, cs, ms)).astype(torch_type_trans(dtype))
+    idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32)
+    weight = np.random.uniform(-10.0, 10.0,
+                               (bs, ns, 3)).astype(torch_type_trans(dtype))
+
+    features_npu = torch.tensor(features, dtype=dtype).to(device)
+    idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
+    weight_npu = torch.tensor(weight, dtype=dtype).to(device)
+
+    expected_output = three_interpolate_forward_golden(features, idx, weight)
+    output = three_interpolate(features_npu, idx_npu, weight_npu)
+
+    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
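Note that `three_interpolate_backward_golden` is defined above but the new test only exercises the forward path. A hedged sketch of a gradient check that would also exercise the backward kernel, assuming it were appended to the body of `test_three_interpolate_npu_dynamic_shape` (it reuses the arrays and tensors already in scope; this snippet is not part of the patch):

```python
features_npu.requires_grad_()
output = three_interpolate(features_npu, idx_npu, weight_npu)

# Push a known upstream gradient through autograd and compare against the
# numpy reference backward.
grad_out = np.ones((bs, cs, ns), dtype=torch_type_trans(dtype))
output.backward(torch.tensor(grad_out, dtype=dtype).to(device))

expected_grad = three_interpolate_backward_golden(grad_out, idx, weight,
                                                  features)
assert np.allclose(features_npu.grad.cpu().numpy(), expected_grad, 1e-3, 1e-4)
```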