Skip to content

Commit 5f708f1

Browse files
[NPU] fix cann8.0.RC2 question. (#1365)
1 parent 5d83e87 commit 5f708f1

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

backends/npu/kernels/conv2d_kernel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ void Conv2DGradKernel(const Context& dev_ctx,
431431
dev_ctx.template Alloc<T>(filter_grad);
432432
filter_grad_tensor = phi::DenseTensor(*filter_grad);
433433
} else {
434-
phi::DenseTensorMeta filter_grad_meta = {input.dtype(), input.dims()};
434+
phi::DenseTensorMeta filter_grad_meta = {filter.dtype(), filter.dims()};
435435
filter_grad_tensor.set_meta(filter_grad_meta);
436436
dev_ctx.template Alloc<T>(&filter_grad_tensor);
437437
}
@@ -445,8 +445,8 @@ void Conv2DGradKernel(const Context& dev_ctx,
445445
dev_ctx.template Alloc<T>(&input_grad_tensor);
446446
}
447447

448-
phi::DenseTensorMeta bias_grad_meta = {input.dtype(),
449-
phi::make_ddim({input.dims()[0]})};
448+
phi::DenseTensorMeta bias_grad_meta = {
449+
input.dtype(), phi::make_ddim({filter_grad_tensor.dims()[0]})};
450450
bias_grad_tensor.set_meta(bias_grad_meta);
451451
dev_ctx.template Alloc<T>(&bias_grad_tensor);
452452

backends/npu/kernels/conv_kernel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ void DepthwiseConv2dGradKernel(const Context& dev_ctx,
416416
dev_ctx.template Alloc<T>(filter_grad);
417417
filter_grad_tensor = phi::DenseTensor(*filter_grad);
418418
} else {
419-
phi::DenseTensorMeta filter_grad_meta = {input.dtype(), input.dims()};
419+
phi::DenseTensorMeta filter_grad_meta = {filter.dtype(), filter.dims()};
420420
filter_grad_tensor.set_meta(filter_grad_meta);
421421
dev_ctx.template Alloc<T>(&filter_grad_tensor);
422422
}
@@ -430,8 +430,8 @@ void DepthwiseConv2dGradKernel(const Context& dev_ctx,
430430
dev_ctx.template Alloc<T>(&input_grad_tensor);
431431
}
432432

433-
phi::DenseTensorMeta bias_grad_meta = {input.dtype(),
434-
phi::make_ddim({input.dims()[0]})};
433+
phi::DenseTensorMeta bias_grad_meta = {
434+
input.dtype(), phi::make_ddim({filter_grad_tensor.dims()[0]})};
435435
bias_grad_tensor.set_meta(bias_grad_meta);
436436
dev_ctx.template Alloc<T>(&bias_grad_tensor);
437437

0 commit comments

Comments
 (0)