Skip to content

Commit 8c20c2a

Browse files
authored
[MLU] fix arange (#1296)
1 parent a874561 commit 8c20c2a

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

backends/mlu/kernels/range_kernel.cc

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@
1616

1717
namespace custom_kernel {
1818

19+
template <typename T>
20+
T GetValue(const phi::DenseTensor* x) {
21+
T value = static_cast<T>(0);
22+
if (x->place().GetType() != phi::AllocationType::CPU) {
23+
phi::DenseTensor cpu_x{};
24+
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
25+
phi::DeviceContext* dev_ctx = pool.Get(x->place());
26+
phi::Copy(*dev_ctx, *x, phi::CPUPlace(), true, &cpu_x);
27+
value = cpu_x.data<T>()[0];
28+
} else {
29+
value = x->data<T>()[0];
30+
}
31+
return value;
32+
}
33+
1934
template <typename T, typename Context>
2035
void ArangeTensorKernel(const Context& dev_ctx,
2136
const phi::DenseTensor& start_t,
@@ -25,11 +40,11 @@ void ArangeTensorKernel(const Context& dev_ctx,
2540
T* h_start_ptr = nullptr;
2641
T* h_end_ptr = nullptr;
2742
T* h_step_ptr = nullptr;
28-
43+
T start_value, end_value, step_value;
2944
if (start_t.place().GetType() == phi::AllocationType::CPU) { // tensor at CPU
30-
h_start_ptr = reinterpret_cast<T*>(const_cast<void*>(GetBasePtr(&start_t)));
31-
h_end_ptr = reinterpret_cast<T*>(const_cast<void*>(GetBasePtr(&end_t)));
32-
h_step_ptr = reinterpret_cast<T*>(const_cast<void*>(GetBasePtr(&step_t)));
45+
start_value = GetValue<T, Context>(dev_ctx, start_t);
46+
end_value = GetValue<T, Context>(dev_ctx, end_t);
47+
step_value = GetValue<T, Context>(dev_ctx, step_t);
3348
} else {
3449
phi::DenseTensor n;
3550
n.Resize(start_t.dims());
@@ -40,12 +55,11 @@ void ArangeTensorKernel(const Context& dev_ctx,
4055
h_end_ptr = new T(n_data[0]);
4156
TensorCopy(dev_ctx, step_t, true, &n, phi::CPUPlace());
4257
h_step_ptr = new T(n_data[0]);
58+
start_value = h_start_ptr[0];
59+
end_value = h_end_ptr[0];
60+
step_value = h_step_ptr[0];
4361
}
4462

45-
T start_value = h_start_ptr[0];
46-
T end_value = h_end_ptr[0];
47-
T step_value = h_step_ptr[0];
48-
4963
ArangeRawKernel<T>(dev_ctx, start_value, end_value, step_value, out);
5064
if (start_t.place().GetType() != phi::AllocationType::CPU) {
5165
delete h_start_ptr;

0 commit comments

Comments
 (0)