16
16
17
17
namespace custom_kernel {
18
18
19
+ template <typename T>
20
+ T GetValue (const phi::DenseTensor* x) {
21
+ T value = static_cast <T>(0 );
22
+ if (x->place ().GetType () != phi::AllocationType::CPU) {
23
+ phi::DenseTensor cpu_x{};
24
+ phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance ();
25
+ phi::DeviceContext* dev_ctx = pool.Get (x->place ());
26
+ phi::Copy (*dev_ctx, *x, phi::CPUPlace (), true , &cpu_x);
27
+ value = cpu_x.data <T>()[0 ];
28
+ } else {
29
+ value = x->data <T>()[0 ];
30
+ }
31
+ return value;
32
+ }
33
+
19
34
template <typename T, typename Context>
20
35
void ArangeTensorKernel (const Context& dev_ctx,
21
36
const phi::DenseTensor& start_t ,
@@ -25,11 +40,11 @@ void ArangeTensorKernel(const Context& dev_ctx,
25
40
T* h_start_ptr = nullptr ;
26
41
T* h_end_ptr = nullptr ;
27
42
T* h_step_ptr = nullptr ;
28
-
43
+ T start_value, end_value, step_value;
29
44
if (start_t .place ().GetType () == phi::AllocationType::CPU) { // tensor at CPU
30
- h_start_ptr = reinterpret_cast <T*>( const_cast < void *>( GetBasePtr (& start_t )) );
31
- h_end_ptr = reinterpret_cast <T*>( const_cast < void *>( GetBasePtr (& end_t )) );
32
- h_step_ptr = reinterpret_cast <T*>( const_cast < void *>( GetBasePtr (& step_t )) );
45
+ start_value = GetValue<T, Context>(dev_ctx, start_t );
46
+ end_value = GetValue<T, Context>(dev_ctx, end_t );
47
+ step_value = GetValue<T, Context>(dev_ctx, step_t );
33
48
} else {
34
49
phi::DenseTensor n;
35
50
n.Resize (start_t .dims ());
@@ -40,12 +55,11 @@ void ArangeTensorKernel(const Context& dev_ctx,
40
55
h_end_ptr = new T (n_data[0 ]);
41
56
TensorCopy (dev_ctx, step_t , true , &n, phi::CPUPlace ());
42
57
h_step_ptr = new T (n_data[0 ]);
58
+ start_value = h_start_ptr[0 ];
59
+ end_value = h_end_ptr[0 ];
60
+ step_value = h_step_ptr[0 ];
43
61
}
44
62
45
- T start_value = h_start_ptr[0 ];
46
- T end_value = h_end_ptr[0 ];
47
- T step_value = h_step_ptr[0 ];
48
-
49
63
ArangeRawKernel<T>(dev_ctx, start_value, end_value, step_value, out);
50
64
if (start_t .place ().GetType () != phi::AllocationType::CPU) {
51
65
delete h_start_ptr;
0 commit comments