@@ -233,53 +233,49 @@ void SetTensorValueKernel(const Context& dev_ctx,
233
233
runner.Run (reinterpret_cast <C_Stream>(dev_ctx.stream ()), tensors);
234
234
}
235
235
236
- // template <typename T, typename Context>
237
- // void SetValueKernel(const Context& dev_ctx,
238
- // const phi::DenseTensor& x,
239
- // const phi::IntArray& starts,
240
- // const phi::IntArray& ends,
241
- // const phi::IntArray& steps,
242
- // const std::vector<int64_t>& axes,
243
- // const std::vector<int64_t>& decrease_axes,
244
- // const std::vector<int64_t>& none_axes,
245
- // const std::vector<int64_t>& shape,
246
- // const std::vector<phi::Scalar>& values,
247
- // phi::DenseTensor* out) {
248
- // std::vector<T> assgin_values;
249
- // assgin_values.reserve(values.size());
250
- // for (const auto& val : values) {
251
- // assgin_values.push_back(val.to<T>());
252
- // }
253
- // phi::DenseTensor value_tensor;
254
- // value_tensor.Resize(phi::make_ddim(shape));
255
- // custom_kernel::TensorFromVector(
256
- // dev_ctx, assgin_values, dev_ctx, &value_tensor);
257
- // value_tensor.Resize(phi::make_ddim(shape));
258
-
259
- // custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
260
- // x,
261
- // value_tensor,
262
- // starts,
263
- // ends,
264
- // steps,
265
- // axes,
266
- // decrease_axes,
267
- // none_axes,
268
- // out);
269
- // }
236
+ template <typename T, typename Context>
237
+ void SetValueKernel (const Context& dev_ctx,
238
+ const phi::DenseTensor& x,
239
+ const phi::IntArray& starts,
240
+ const phi::IntArray& ends,
241
+ const phi::IntArray& steps,
242
+ const std::vector<int64_t >& axes,
243
+ const std::vector<int64_t >& decrease_axes,
244
+ const std::vector<int64_t >& none_axes,
245
+ const std::vector<int64_t >& shape,
246
+ const std::vector<phi::Scalar>& values,
247
+ phi::DenseTensor* out) {
248
+ std::vector<T> assign_values;
249
+ assign_values.reserve (values.size ());
250
+ for (const auto & val : values) {
251
+ assign_values.push_back (val.to <T>());
252
+ }
253
+ phi::DenseTensor value_tensor;
254
+ value_tensor.Resize (phi::make_ddim (shape));
255
+ TensorFromVector (dev_ctx, assign_values, dev_ctx, &value_tensor);
256
+ value_tensor.Resize (phi::make_ddim (shape));
270
257
271
- //
258
+ custom_kernel::SetTensorValueKernel<T, Context>(dev_ctx,
259
+ x,
260
+ value_tensor,
261
+ starts,
262
+ ends,
263
+ steps,
264
+ axes,
265
+ decrease_axes,
266
+ none_axes,
267
+ out);
268
+ }
272
269
273
270
} // namespace custom_kernel
274
271
275
- // PD_REGISTER_PLUGIN_KERNEL(set_value,
276
- // intel_hpu,
277
- // ALL_LAYOUT,
278
- // custom_kernel::SetValueKernel,
279
- // float,
280
- // phi::dtype::float16,
281
- // phi::dtype::bfloat16) {
282
- // }
272
+ PD_REGISTER_PLUGIN_KERNEL (set_value,
273
+ intel_hpu,
274
+ ALL_LAYOUT,
275
+ custom_kernel::SetValueKernel,
276
+ float ,
277
+ phi::dtype::float16,
278
+ phi::dtype::bfloat16) {}
283
279
284
280
PD_REGISTER_PLUGIN_KERNEL (set_value_with_tensor,
285
281
intel_hpu,
0 commit comments