Skip to content

Commit bd66cdf

Browse files
authored
[INTEL_HPU] Add set value OP (#1852)
Signed-off-by: Fei Wang <[email protected]>
1 parent 4a73fc3 commit bd66cdf

File tree

2 files changed

+67
-43
lines changed

2 files changed

+67
-43
lines changed

backends/intel_hpu/kernels/funcs.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,34 @@ inline paddle::Tensor copy_tensor_wrapper(const phi::CustomContext* dev_ctx,
206206
return paddle::Tensor(dst_dt);
207207
}
208208

209+
/**
210+
* CPU -> INTEL_HPU
211+
*/
212+
template <typename T>
213+
inline void TensorFromVector(const phi::CustomContext& ctx,
214+
const std::vector<T>& src,
215+
const phi::CustomContext& dev_ctx,
216+
phi::DenseTensor* dst) {
217+
auto dst_place = dev_ctx.GetPlace();
218+
C_Device_st device{dst_place.GetDeviceId()};
219+
auto src_ptr = static_cast<const void*>(src.data());
220+
dst->Resize({static_cast<int64_t>(src.size())});
221+
auto dst_ptr = static_cast<void*>(dev_ctx.template Alloc<T>(dst));
222+
auto size = src.size() * sizeof(T);
223+
if (UNLIKELY(size == 0)) return;
224+
225+
if (dst_place.GetType() == phi::AllocationType::CUSTOM) {
226+
AsyncMemCpyH2D(&device,
227+
static_cast<C_Stream>(dev_ctx.stream()),
228+
dst_ptr,
229+
src_ptr,
230+
size);
231+
} else {
232+
PADDLE_THROW(phi::errors::Unimplemented(
233+
"TensorFromVector on %s is not supported.", dst_place));
234+
}
235+
}
236+
209237
inline int CanonicalAxis(const int axis, const int rank) {
210238
if (axis < 0) {
211239
return axis + rank;

backends/intel_hpu/kernels/set_value_kernel.cc

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -233,53 +233,49 @@ void SetTensorValueKernel(const Context& dev_ctx,
233233
runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
234234
}
235235

236-
// template <typename T, typename Context>
237-
// void SetValueKernel(const Context& dev_ctx,
238-
// const phi::DenseTensor& x,
239-
// const phi::IntArray& starts,
240-
// const phi::IntArray& ends,
241-
// const phi::IntArray& steps,
242-
// const std::vector<int64_t>& axes,
243-
// const std::vector<int64_t>& decrease_axes,
244-
// const std::vector<int64_t>& none_axes,
245-
// const std::vector<int64_t>& shape,
246-
// const std::vector<phi::Scalar>& values,
247-
// phi::DenseTensor* out) {
248-
// std::vector<T> assgin_values;
249-
// assgin_values.reserve(values.size());
250-
// for (const auto& val : values) {
251-
// assgin_values.push_back(val.to<T>());
252-
// }
253-
// phi::DenseTensor value_tensor;
254-
// value_tensor.Resize(phi::make_ddim(shape));
255-
// custom_kernel::TensorFromVector(
256-
// dev_ctx, assgin_values, dev_ctx, &value_tensor);
257-
// value_tensor.Resize(phi::make_ddim(shape));
258-
259-
// custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
260-
// x,
261-
// value_tensor,
262-
// starts,
263-
// ends,
264-
// steps,
265-
// axes,
266-
// decrease_axes,
267-
// none_axes,
268-
// out);
269-
// }
236+
template <typename T, typename Context>
237+
void SetValueKernel(const Context& dev_ctx,
238+
const phi::DenseTensor& x,
239+
const phi::IntArray& starts,
240+
const phi::IntArray& ends,
241+
const phi::IntArray& steps,
242+
const std::vector<int64_t>& axes,
243+
const std::vector<int64_t>& decrease_axes,
244+
const std::vector<int64_t>& none_axes,
245+
const std::vector<int64_t>& shape,
246+
const std::vector<phi::Scalar>& values,
247+
phi::DenseTensor* out) {
248+
std::vector<T> assign_values;
249+
assign_values.reserve(values.size());
250+
for (const auto& val : values) {
251+
assign_values.push_back(val.to<T>());
252+
}
253+
phi::DenseTensor value_tensor;
254+
value_tensor.Resize(phi::make_ddim(shape));
255+
TensorFromVector(dev_ctx, assign_values, dev_ctx, &value_tensor);
256+
value_tensor.Resize(phi::make_ddim(shape));
270257

271-
//
258+
custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
259+
x,
260+
value_tensor,
261+
starts,
262+
ends,
263+
steps,
264+
axes,
265+
decrease_axes,
266+
none_axes,
267+
out);
268+
}
272269

273270
} // namespace custom_kernel
274271

275-
// PD_REGISTER_PLUGIN_KERNEL(set_value,
276-
// intel_hpu,
277-
// ALL_LAYOUT,
278-
// custom_kernel::SetValueKernel,
279-
// float,
280-
// phi::dtype::float16,
281-
// phi::dtype::bfloat16) {
282-
// }
272+
PD_REGISTER_PLUGIN_KERNEL(set_value,
273+
intel_hpu,
274+
ALL_LAYOUT,
275+
custom_kernel::SetValueKernel,
276+
float,
277+
phi::dtype::float16,
278+
phi::dtype::bfloat16) {}
283279

284280
PD_REGISTER_PLUGIN_KERNEL(set_value_with_tensor,
285281
intel_hpu,

0 commit comments

Comments
 (0)