Commit 61c490d

[NPU] fix ut question: test_zero_dim. (#1361)
1 parent 7d50a37 commit 61c490d

7 files changed: +125 −23 lines

backends/npu/kernels/expand_kernel.cc

Lines changed: 20 additions & 5 deletions
@@ -274,11 +274,22 @@ void ExpandKernel(const Context& dev_ctx,
 
   phi::DDim out_dims = phi::make_ddim(final_expand_shape);
   out->Resize(out_dims);
-  dev_ctx.template Alloc<T>(out);
 
-  if (x.dtype() == phi::DataType::FLOAT64) {
+  if (out_dims.size() == 0) {
+    out->Resize(phi::make_ddim({1}));
+    final_expand_shape = {1};
+  }
+
+  phi::DenseTensor x_trans(x);
+  if (x.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {x.dtype(), phi::make_ddim({1})};
+    x_trans.set_meta(meta_1);
+  }
+
+  dev_ctx.template Alloc<T>(out);
+  if (x_trans.dtype() == phi::DataType::FLOAT64) {
     phi::DenseTensor cast_x;
-    phi::DenseTensorMeta cast_x_meta = {phi::DataType::FLOAT32, x.dims()};
+    phi::DenseTensorMeta cast_x_meta = {phi::DataType::FLOAT32, x_trans.dims()};
     cast_x.set_meta(cast_x_meta);
     dev_ctx.template Alloc<float>(&cast_x);
 
@@ -288,12 +299,16 @@ void ExpandKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<float>(&cast_out);
 
     custom_kernel::CastKernel<T, Context>(
-        dev_ctx, x, phi::DataType::FLOAT32, &cast_x);
+        dev_ctx, x_trans, phi::DataType::FLOAT32, &cast_x);
     EXEC_NPU_CMD(aclnnExpand, dev_ctx, cast_x, final_expand_shape, cast_out);
     custom_kernel::CastKernel<T, Context>(
         dev_ctx, cast_out, phi::DataType::FLOAT64, out);
   } else {
-    EXEC_NPU_CMD(aclnnExpand, dev_ctx, x, final_expand_shape, *out);
+    EXEC_NPU_CMD(aclnnExpand, dev_ctx, x_trans, final_expand_shape, *out);
+  }
+
+  if (out_dims.size() == 0) {
+    out->Resize(out_dims);
   }
 }
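The zero-dim handling here follows one pattern used throughout this commit: view the 0-d input as shape [1] before handing it to the aclnn op (which presumably mishandles rank-0 tensors), run the op, then resize the output back to 0-d. A minimal sketch of the user-visible behavior this enables, assuming a Paddle build with this NPU plugin and 0-d tensor semantics (paddle >= 2.5); the same lines can be sanity-checked on CPU:

import paddle

x = paddle.to_tensor(1.5)           # 0-d tensor, shape []
y = paddle.expand(x, shape=[2, 3])  # kernel internally views x as shape [1]
print(y.shape)                      # [2, 3]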

backends/npu/kernels/gather_kernel.cc

Lines changed: 46 additions & 3 deletions
@@ -65,11 +65,30 @@ void GatherKernel(const Context& dev_ctx,
   DO_COMPATIBILITY(aclnnGatherV2,
                    (custom_kernel::AclopGatherKernel<T, Context>(
                        dev_ctx, x, index, axis, out)));
-  auto stream = dev_ctx.stream();
-  dev_ctx.template Alloc<T>(out);
 
   int64_t dim = axis.to<int64_t>();
 
+  auto out_dims = out->dims();
+  bool zero_index = false;
+  int count = 0;
+  if (index.dims().size() == 0) {
+    std::vector<int64_t> out_dims_new(out_dims.size() + 1);
+    for (int64_t i = 0; i <= out_dims_new.size() - 1; i++) {
+      if (i == dim) {
+        out_dims_new[i] = 1;
+      } else {
+        out_dims_new[i] = out_dims[count];
+        count++;
+      }
+    }
+
+    phi::DenseTensorMeta meta_1 = {x.dtype(), phi::make_ddim(out_dims_new)};
+    out->set_meta(meta_1);
+    zero_index = true;
+  }
+
+  dev_ctx.template Alloc<T>(out);
+
   auto index_shape_vec = phi::vectorize(index.dims());
   if (index_shape_vec.size() == 2 && index_shape_vec[1] == 1) {
     const phi::DenseTensor* p_index = &index;
@@ -82,6 +101,11 @@ void GatherKernel(const Context& dev_ctx,
   } else {
     EXEC_NPU_CMD(aclnnGatherV2, dev_ctx, x, dim, index, *out);
   }
+
+  if (zero_index) {
+    phi::DenseTensorMeta meta_0 = {x.dtype(), out_dims};
+    out->set_meta(meta_0);
+  }
 }
 
 template <typename T, typename Context>
@@ -118,8 +142,27 @@ void GatherGradKernel(const Context& dev_ctx,
   zeroslike_xout.Resize(x.dims());
 
   // step3: scatter(x_grad)
+  phi::DenseTensor out_grad_(out_grad);
+
+  int64_t dim = axis.to<int64_t>();
+  int count = 0;
+  if (index.dims().size() == 0) {
+    std::vector<int64_t> out_dims_new(out_grad.dims().size() + 1);
+    for (int64_t i = 0; i <= out_dims_new.size() - 1; i++) {
+      if (i == dim) {
+        out_dims_new[i] = 1;
+      } else {
+        out_dims_new[i] = out_grad.dims()[count];
+        count++;
+      }
+    }
+
+    phi::DenseTensorMeta meta_1 = {x.dtype(), phi::make_ddim(out_dims_new)};
+    out_grad_.set_meta(meta_1);
+  }
+
   EXEC_NPU_CMD(
-      aclnnScatterNd, dev_ctx, zeroslike_xout, *p_index, out_grad, *x_grad);
+      aclnnScatterNd, dev_ctx, zeroslike_xout, *p_index, out_grad_, *x_grad);
 }
 
 } // namespace custom_kernel
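With a 0-d index, the gathered axis should be dropped from the output, but the aclnn op apparently still produces it, so the kernel temporarily inserts a size-1 dimension at dim into out's meta, runs aclnnGatherV2, and restores the axis-dropped meta afterward (the zero_index flag); GatherGradKernel mirrors this by re-inserting the size-1 axis into out_grad before aclnnScatterNd. A rough sketch of the case being fixed, under the same 0-d-semantics assumption as above:

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
idx = paddle.to_tensor(1)            # 0-d index tensor
out = paddle.gather(x, idx, axis=0)  # gathered axis is dropped: shape [2]
print(out.shape)                     # [2]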

backends/npu/kernels/pad_kernel.cc

Lines changed: 6 additions & 2 deletions
@@ -78,7 +78,9 @@ PD_REGISTER_PLUGIN_KERNEL(pad,
                           npu,
                           ALL_LAYOUT,
                           custom_kernel::PadKernel,
-                          int,
+                          int16_t,
+                          int32_t,
+                          int64_t,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16,
@@ -88,7 +90,9 @@ PD_REGISTER_PLUGIN_KERNEL(pad_grad,
                           npu,
                           ALL_LAYOUT,
                           custom_kernel::PadGradKernel,
-                          int,
+                          int16_t,
+                          int32_t,
+                          int64_t,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16,
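Registering int16_t, int32_t, and int64_t (instead of a bare int) widens the integer dtype coverage of pad and pad_grad on the NPU. A quick check, assuming the plugin is installed and paddle.set_device("npu") has been called; note that paddle's constant-mode pad list runs from the first dimension to the last:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[1, 2], [3, 4]], dtype="int64")
y = F.pad(x, [0, 0, 1, 1], mode="constant", value=0)  # pad the last dim by 1 on each side
print(y.shape)  # [2, 4]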

backends/npu/kernels/softmax_kernel.cc

Lines changed: 46 additions & 9 deletions
@@ -56,7 +56,23 @@ void SoftmaxKernel(const Context& dev_ctx,
       (custom_kernel::AclopSoftmaxKernel<T, Context>(dev_ctx, x, axis, out)));
   dev_ctx.template Alloc<T>(out);
   int64_t dim = static_cast<int64_t>(axis);
-  EXEC_NPU_CMD(aclnnSoftmax, dev_ctx, x, dim, *out);
+
+  phi::DenseTensor x_trans(x);
+  if (x.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {x.dtype(), phi::make_ddim({1})};
+    x_trans.set_meta(meta_1);
+  }
+
+  auto out_dims = out->dims();
+  if (out_dims.size() == 0) {
+    out->Resize(phi::make_ddim({1}));
+  }
+
+  EXEC_NPU_CMD(aclnnSoftmax, dev_ctx, x_trans, dim, *out);
+
+  if (out_dims.size() == 0) {
+    out->Resize(out_dims);
+  }
 }
 
 template <typename T, typename Context>
@@ -148,31 +164,52 @@ void SoftmaxGradKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(x_grad);
   int64_t dim = static_cast<int64_t>(axis);
 
+  phi::DenseTensor x_trans(out_grad);
+  if (out_grad.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {out_grad.dtype(), phi::make_ddim({1})};
+    x_trans.set_meta(meta_1);
+  }
+
   phi::DenseTensor cast_x;
-  if (out_grad.dtype() == phi::DataType::FLOAT64) {
-    phi::DenseTensorMeta meta(out_grad.meta());
+  if (x_trans.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(x_trans.meta());
     meta.dtype = phi::DataType::FLOAT32;
     cast_x.set_meta(meta);
 
     custom_kernel::CastKernel<T, Context>(
-        dev_ctx, out_grad, phi::DataType::FLOAT32, &cast_x);
+        dev_ctx, x_trans, phi::DataType::FLOAT32, &cast_x);
   } else {
-    cast_x = out_grad;
+    cast_x = x_trans;
+  }
+
+  phi::DenseTensor y_trans(out);
+  if (out.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {out.dtype(), phi::make_ddim({1})};
+    y_trans.set_meta(meta_1);
   }
 
   phi::DenseTensor cast_y;
-  if (out.dtype() == phi::DataType::FLOAT64) {
-    phi::DenseTensorMeta meta(out.meta());
+  if (y_trans.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(y_trans.meta());
     meta.dtype = phi::DataType::FLOAT32;
     cast_y.set_meta(meta);
 
     custom_kernel::CastKernel<T, Context>(
-        dev_ctx, out, phi::DataType::FLOAT32, &cast_y);
+        dev_ctx, y_trans, phi::DataType::FLOAT32, &cast_y);
   } else {
-    cast_y = out;
+    cast_y = y_trans;
+  }
+
+  auto x_grad_dims = x_grad->dims();
+  if (x_grad_dims.size() == 0) {
+    x_grad->Resize(phi::make_ddim({1}));
   }
 
   EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, cast_x, cast_y, dim, *x_grad);
+
+  if (x_grad_dims.size() == 0) {
+    x_grad->Resize(x_grad_dims);
+  }
 }
 
 } // namespace custom_kernel
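Softmax over a 0-d tensor degenerates to a single element, so both kernels view their 0-d operands as shape [1] for aclnnSoftmax / aclnnSoftmaxBackward and resize the results back to 0-d afterwards. The expected end-to-end behavior, sketched under the same 0-d-semantics assumption:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(0.7, stop_gradient=False)  # 0-d input
y = F.softmax(x)          # softmax of a single element is 1.0; shape stays []
y.backward()
print(y.shape, float(y))  # [] 1.0
print(x.grad.shape)       # [] (gradient is 0-d as well)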

backends/npu/tests/unittests/test_kldiv_loss_op_npu.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def test_check_grad(self):
             ["X"],
             "Loss",
             no_grad_set=set(["Target"]),
-            max_relative_error=0.15,
+            max_relative_error=0.2,
         )
 
     def initTestCase(self):
@@ -99,7 +99,7 @@ def test_check_grad(self):
             ["X"],
             "Loss",
             no_grad_set=set(["Target"]),
-            max_relative_error=0.16,
+            max_relative_error=0.2,
         )

backends/npu/tests/unittests/test_sequence_mask_op_npu.py

Lines changed: 5 additions & 0 deletions
@@ -121,3 +121,8 @@ def test_dygraph_api(self):
         for r in [out]:
             np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
         paddle.enable_static()
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
Lines changed: 0 additions & 2 deletions

@@ -1,6 +1,4 @@
 disable_ut_npu
 test_check_nan_inf_op_npu
 test_conv3d_op_npu
-test_fused_matmul_bias_op_npu
-test_zero_dim_tensor_npu
 test_matmulv2_op_npu
