Skip to content

Commit f787130

Browse files
[NPU] fix ut question: kldiv_loss. (#1359)
1 parent 18e95b9 commit f787130

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

backends/npu/kernels/kldiv_loss_kernel.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ void KLDivLossKernel(const Context& dev_ctx,
2222
const phi::DenseTensor& x,
2323
const phi::DenseTensor& label,
2424
const std::string& reduction,
25+
bool log_target,
2526
phi::DenseTensor* out) {
27+
PADDLE_ENFORCE_EQ(
28+
log_target,
29+
false,
30+
phi::errors::InvalidArgument("PaddlePaddle does not support parameters "
31+
"log_target is true on the NPU."));
32+
2633
dev_ctx.template Alloc<T>(out);
2734

2835
auto stream = dev_ctx.stream();
@@ -49,7 +56,13 @@ void KLDivLossGradKernel(const Context& dev_ctx,
4956
const phi::DenseTensor& label,
5057
const phi::DenseTensor& d_out,
5158
const std::string& reduction,
59+
bool log_target,
5260
phi::DenseTensor* d_x) {
61+
PADDLE_ENFORCE_EQ(
62+
log_target,
63+
false,
64+
phi::errors::InvalidArgument("PaddlePaddle does not support parameters "
65+
"log_target is true on the NPU."));
5366
dev_ctx.template Alloc<T>(d_x);
5467

5568
auto stream = dev_ctx.stream();

backends/npu/tools/disable_ut_npu_910b

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ test_check_nan_inf_op_npu
33
test_conv3d_op_npu
44
test_contiguous_op_npu
55
test_fused_matmul_bias_op_npu
6-
test_kldiv_loss_op_npu
76
test_zero_dim_tensor_npu
87
test_matmulv2_op_npu

0 commit comments

Comments
 (0)