Skip to content

Commit e760c4e

Browse files
【npu】fix cast and arange for npu (#1630)
1 parent ff885fc commit e760c4e

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

backends/npu/kernels/arange_kernel.cc

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -114,9 +114,6 @@ void AclopArangeKernel(const Context& dev_ctx,
114114
int64_t size = 0;
115115
GetSize(start_value, end_value, step_value, &size);
116116

117-
out->Resize(phi::make_ddim({size}));
118-
dev_ctx.template Alloc<T>(out);
119-
120117
std::vector<T> odata;
121118
T value = start_value;
122119
for (int64_t i = 0; i < size; ++i) {
@@ -136,7 +133,16 @@ void ArangeKernel(const Context& dev_ctx,
136133
DO_COMPATIBILITY(aclnnArange,
137134
(custom_kernel::AclopArangeKernel<T, Context>(
138135
dev_ctx, start, end, step, out)));
136+
T start_value = start.to<T>();
137+
T end_value = end.to<T>();
138+
T step_value = step.to<T>();
139+
140+
int64_t size = 0;
141+
GetSize(start_value, end_value, step_value, &size);
142+
143+
out->Resize(phi::make_ddim({size}));
139144
dev_ctx.template Alloc<T>(out);
145+
140146
EXEC_NPU_CMD(aclnnArange, dev_ctx, start, end, step, *out);
141147
}
142148

backends/npu/kernels/cast_kernel.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,20 @@ void CastKernel(const Context& dev_ctx,
9393
DO_COMPATIBILITY(
9494
aclnnCast,
9595
(custom_kernel::AclopCastKernel<T, Context>(dev_ctx, x, dtype, out)));
96-
96+
out->Resize((x.dims()));
9797
if (x.dtype() == dtype) {
98+
if (x.dims() == phi::make_ddim({-1})) {
99+
*out = x;
100+
return;
101+
}
98102
dev_ctx.template Alloc<T>(out);
99103
TensorCopy(dev_ctx, x, false, out);
100104
return;
101105
}
102-
106+
if (x.dims() == phi::make_ddim({-1})) {
107+
PADDLE_THROW(phi::errors::InvalidArgument(
108+
"canot cast tensor with unknown shape for diffrent dtype"));
109+
}
103110
int aclDtype = ConvertToNpuDtype(dtype);
104111

105112
if (dtype == phi::DataType::FLOAT32) {

0 commit comments

Comments
 (0)