Skip to content

Commit 9b65d57

Browse files
authored
[NPU] fix index_put_grad (#1425)
1 parent 01671f4 commit 9b65d57

File tree

1 file changed

+63
-5
lines changed

1 file changed

+63
-5
lines changed

backends/npu/kernels/index_put_kernel.cc

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@
1919

2020
namespace custom_kernel {
2121

22+
template <typename T, typename Context>
23+
void StackKernel(const Context& dev_ctx,
24+
const std::vector<const phi::DenseTensor*>& x,
25+
int axis,
26+
phi::DenseTensor* y);
27+
28+
template <typename T, typename Context>
29+
void GatherNdKernel(const Context& dev_ctx,
30+
const phi::DenseTensor& x,
31+
const phi::DenseTensor& index,
32+
phi::DenseTensor* out);
33+
34+
template <typename T, typename Context>
35+
void NonZeroKernel(const Context& dev_ctx,
36+
const phi::DenseTensor& condition,
37+
phi::DenseTensor* out);
38+
2239
template <typename T, typename Context>
2340
void CastKernel(const Context& dev_ctx,
2441
const phi::DenseTensor& x,
@@ -59,8 +76,10 @@ void IndexPutGradKernel(const Context& dev_ctx,
5976
const phi::DenseTensor& x,
6077
const std::vector<const phi::DenseTensor*>& indices,
6178
const phi::DenseTensor& value,
79+
const phi::DenseTensor& out_grad,
6280
bool accumulate,
63-
phi::DenseTensor* out) {
81+
phi::DenseTensor* x_grad,
82+
phi::DenseTensor* value_grad) {
6483
bool unsafe = true;
6584

6685
std::vector<phi::DenseTensor*> tensor_list(indices.size());
@@ -76,10 +95,49 @@ void IndexPutGradKernel(const Context& dev_ctx,
7695
}
7796
}
7897

79-
EXEC_NPU_CMD(
80-
aclnnIndexPutImpl, dev_ctx, x, tensor_list, value, accumulate, unsafe);
81-
dev_ctx.template Alloc<T>(out);
82-
TensorCopy(dev_ctx, x, true, out);
98+
if (x_grad) {
99+
dev_ctx.template Alloc<T>(x_grad);
100+
TensorCopy(dev_ctx, out_grad, true, x_grad);
101+
phi::DenseTensorMeta value_zero_meta = {value.dtype(), value.dims()};
102+
phi::DenseTensor value_zero;
103+
value_zero.set_meta(value_zero_meta);
104+
dev_ctx.template Alloc<T>(&value_zero);
105+
EXEC_NPU_CMD(aclnnInplaceZero, dev_ctx, value_zero);
106+
EXEC_NPU_CMD(aclnnIndexPutImpl,
107+
dev_ctx,
108+
*x_grad,
109+
tensor_list,
110+
value_zero,
111+
accumulate,
112+
unsafe);
113+
}
114+
115+
if (value_grad) {
116+
dev_ctx.template Alloc<T>(value_grad);
117+
if (tensor_list[0]->dtype() == phi::DataType::BOOL) {
118+
// deal with bool indices
119+
PADDLE_ENFORCE_EQ(
120+
tensor_list.size(),
121+
1,
122+
phi::errors::InvalidArgument("bool indices should be 1d"));
123+
124+
phi::DenseTensor non_zero_index;
125+
custom_kernel::NonZeroKernel<int64_t, Context>(
126+
dev_ctx, *tensor_list[0], &non_zero_index);
127+
custom_kernel::GatherNdKernel<T, Context>(
128+
dev_ctx, out_grad, non_zero_index, value_grad);
129+
} else {
130+
phi::DenseTensorMeta index_tensor_meta = {
131+
tensor_list[0]->dtype(),
132+
phi::make_ddim({tensor_list[0]->dims()[0], tensor_list.size()})};
133+
phi::DenseTensor index_tensor;
134+
index_tensor.set_meta(index_tensor_meta);
135+
custom_kernel::StackKernel<int64_t, Context>(
136+
dev_ctx, indices, -1, &index_tensor);
137+
custom_kernel::GatherNdKernel<T, Context>(
138+
dev_ctx, out_grad, index_tensor, value_grad);
139+
}
140+
}
83141
}
84142

85143
} // namespace custom_kernel

0 commit comments

Comments
 (0)