Skip to content

Commit 7d50a37

Browse files
【NPU】 fix CANN_930 question (#1419)
1 parent 8a560ba commit 7d50a37

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

backends/npu/kernels/clip_by_norm_kernel.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,15 @@ void NormKernel(const Context& dev_ctx,
5858

5959
phi::Scalar p = 2.0f;
6060
const auto& x_dims = x.dims();
61-
std::vector<int64_t> axis;
61+
std::vector<int64_t> axis, resize_list;
6262
for (int64_t i = 0; i < x_dims.size(); ++i) {
6363
axis.push_back(i);
64+
resize_list.push_back(1);
6465
}
65-
bool keepdim = false;
66+
bool keepdim = true;
67+
x_norm->Resize(phi::make_ddim(resize_list));
6668
EXEC_NPU_CMD(aclnnNorm, dev_ctx, x, p, axis, keepdim, *x_norm);
69+
x_norm->Resize({1});
6770
}
6871

6972
template <typename T, typename Context>

backends/npu/kernels/concat_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ void ConcatGradKernel(const Context& dev_ctx,
215215
ends_array.push_back(ins[j]->dims()[axis] + offset);
216216

217217
std::vector<int64_t> steps;
218-
for (int i = 0; i < outs[j]->dims().size(); i++) {
218+
for (int i = 0; i < axes_t.size(); i++) {
219219
steps.push_back(1.0);
220220
}
221221

backends/npu/kernels/slice_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ void SliceRawKernel(const Context& dev_ctx,
184184
out->Resize(phi::make_ddim(size));
185185

186186
std::vector<int64_t> steps;
187-
for (int i = 0; i < out->dims().size(); i++) {
187+
for (int i = 0; i < axes_t.size(); i++) {
188188
steps.push_back(1.0);
189189
}
190190

backends/npu/kernels/tile_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void TileGradKernel(const Context& dev_ctx,
316316
for (int i = 0; i < out_grad.dims().size(); i++) {
317317
axes.push_back(i);
318318
}
319-
std::vector<int64_t> steps(out_grad.dims().size(), 1);
319+
std::vector<int64_t> steps(axes.size(), 1);
320320
static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
321321
auto starts_acl = aclCreateIntArray(starts.data(), starts.size());
322322
auto ends_acl = aclCreateIntArray(x_grad_dims.data(), x_grad_dims.size());

0 commit comments

Comments
 (0)