19
19
20
20
namespace custom_kernel {
21
21
22
// Forward declaration of StackKernel; its definition is not visible in this
// file view. IndexPutGradKernel (below) calls it as
// StackKernel<int64_t, Context>(dev_ctx, indices, -1, &index_tensor) on the
// non-bool index path, writing into a tensor whose meta was set beforehand.
// NOTE(review): presumably stacks the tensors in `x` along `axis` into `y` --
// confirm against the kernel's definition.
+ template <typename T, typename Context>
23
+ void StackKernel (const Context& dev_ctx,
24
+ const std::vector<const phi::DenseTensor*>& x,
25
+ int axis,
26
+ phi::DenseTensor* y);
27
+
28
// Forward declaration of GatherNdKernel; its definition is not visible in this
// file view. IndexPutGradKernel (below) calls it with out_grad as `x`, an
// index tensor (from NonZeroKernel or StackKernel) as `index`, and value_grad
// as `out`.
// NOTE(review): presumably gathers the elements of `x` addressed by the rows
// of `index` into `out` -- confirm against the kernel's definition.
+ template <typename T, typename Context>
29
+ void GatherNdKernel (const Context& dev_ctx,
30
+ const phi::DenseTensor& x,
31
+ const phi::DenseTensor& index,
32
+ phi::DenseTensor* out);
33
+
34
// Forward declaration of NonZeroKernel; its definition is not visible in this
// file view. IndexPutGradKernel (below) calls it as
// NonZeroKernel<int64_t, Context> on an index tensor whose dtype is BOOL and
// passes the result to GatherNdKernel as the `index` argument.
// NOTE(review): presumably writes the coordinates of the non-zero (true)
// elements of `condition` into `out` -- confirm against the definition.
+ template <typename T, typename Context>
35
+ void NonZeroKernel (const Context& dev_ctx,
36
+ const phi::DenseTensor& condition,
37
+ phi::DenseTensor* out);
38
+
22
39
template <typename T, typename Context>
23
40
void CastKernel (const Context& dev_ctx,
24
41
const phi::DenseTensor& x,
@@ -59,8 +76,10 @@ void IndexPutGradKernel(const Context& dev_ctx,
59
76
const phi::DenseTensor& x,
60
77
const std::vector<const phi::DenseTensor*>& indices,
61
78
const phi::DenseTensor& value,
79
+ const phi::DenseTensor& out_grad,
62
80
bool accumulate,
63
- phi::DenseTensor* out) {
81
+ phi::DenseTensor* x_grad,
82
+ phi::DenseTensor* value_grad) {
64
83
bool unsafe = true ;
65
84
66
85
std::vector<phi::DenseTensor*> tensor_list (indices.size ());
@@ -76,10 +95,49 @@ void IndexPutGradKernel(const Context& dev_ctx,
76
95
}
77
96
}
78
97
79
- EXEC_NPU_CMD (
80
- aclnnIndexPutImpl, dev_ctx, x, tensor_list, value, accumulate, unsafe);
81
- dev_ctx.template Alloc <T>(out);
82
- TensorCopy (dev_ctx, x, true , out);
98
+ if (x_grad) {
99
+ dev_ctx.template Alloc <T>(x_grad);
100
+ TensorCopy (dev_ctx, out_grad, true , x_grad);
101
+ phi::DenseTensorMeta value_zero_meta = {value.dtype (), value.dims ()};
102
+ phi::DenseTensor value_zero;
103
+ value_zero.set_meta (value_zero_meta);
104
+ dev_ctx.template Alloc <T>(&value_zero);
105
+ EXEC_NPU_CMD (aclnnInplaceZero, dev_ctx, value_zero);
106
+ EXEC_NPU_CMD (aclnnIndexPutImpl,
107
+ dev_ctx,
108
+ *x_grad,
109
+ tensor_list,
110
+ value_zero,
111
+ accumulate,
112
+ unsafe);
113
+ }
114
+
115
+ if (value_grad) {
116
+ dev_ctx.template Alloc <T>(value_grad);
117
+ if (tensor_list[0 ]->dtype () == phi::DataType::BOOL) {
118
+ // deal with bool indices
119
+ PADDLE_ENFORCE_EQ (
120
+ tensor_list.size (),
121
+ 1 ,
122
+ phi::errors::InvalidArgument (" bool indices should be 1d" ));
123
+
124
+ phi::DenseTensor non_zero_index;
125
+ custom_kernel::NonZeroKernel<int64_t , Context>(
126
+ dev_ctx, *tensor_list[0 ], &non_zero_index);
127
+ custom_kernel::GatherNdKernel<T, Context>(
128
+ dev_ctx, out_grad, non_zero_index, value_grad);
129
+ } else {
130
+ phi::DenseTensorMeta index_tensor_meta = {
131
+ tensor_list[0 ]->dtype (),
132
+ phi::make_ddim ({tensor_list[0 ]->dims ()[0 ], tensor_list.size ()})};
133
+ phi::DenseTensor index_tensor;
134
+ index_tensor.set_meta (index_tensor_meta);
135
+ custom_kernel::StackKernel<int64_t , Context>(
136
+ dev_ctx, indices, -1 , &index_tensor);
137
+ custom_kernel::GatherNdKernel<T, Context>(
138
+ dev_ctx, out_grad, index_tensor, value_grad);
139
+ }
140
+ }
83
141
}
84
142
85
143
} // namespace custom_kernel
0 commit comments