Skip to content

Commit a07fccd

Browse files
[INTEL_HPU] support index_put kernel (#1612)
1 parent 7dddf3e commit a07fccd

File tree

3 files changed

+717
-0
lines changed

3 files changed

+717
-0
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "habanalabs/perf_lib_layer_params.h"
16+
#include "kernels/funcs.h"
17+
#include "kernels/hpu_operator.h"
18+
#include "utils/utils.h"
19+
20+
namespace custom_kernel {
21+
22+
class IndexPut : public HpuOperator {
23+
public:
24+
IndexPut() : HpuOperator("scatter_nd_onnx_fwd_", false) {}
25+
26+
void AddNode(ConvertTensors& ct, bool is_inplace) {
27+
auto inputs = ct.GetTensors();
28+
auto outputs = ct.GetTensors(false);
29+
30+
std::vector<synTensor> syn_scatter_inputs;
31+
synSectionHandle section_shared = nullptr;
32+
if (is_inplace) section_shared = createSection();
33+
syn_scatter_inputs.push_back(createTensor(inputs[0].dims.size(),
34+
inputs[0].type,
35+
inputs[0].dims,
36+
true,
37+
inputs[0].name,
38+
section_shared));
39+
guid_ = guid_ + SynDataTypeToStr(inputs[0].type);
40+
41+
DIMS tmp = inputs[1].dims;
42+
std::vector<synTensor> syn_concat_inputs;
43+
for (size_t i = 1; i < inputs.size() - 1; i++) {
44+
// unsqueeze on dim0 directly [x, 1]
45+
inputs[i].dims.emplace_back(1);
46+
syn_concat_inputs.push_back(createTensor(inputs[i].dims.size(),
47+
inputs[i].type,
48+
inputs[i].dims,
49+
true,
50+
inputs[i].name));
51+
}
52+
// concat indices
53+
std::string guid_concat = "concat";
54+
std::string name_concat = guid_ + "concat";
55+
synConcatenateParams concatParams;
56+
concatParams.axis = 0;
57+
std::vector<synTensor> syn_concat_outputs;
58+
DIMS indices_concat_dims = inputs[1].dims;
59+
indices_concat_dims.back() = syn_concat_inputs.size();
60+
syn_concat_outputs.push_back(createTensor(indices_concat_dims.size(),
61+
inputs[1].type,
62+
indices_concat_dims,
63+
false,
64+
"indices_concat"));
65+
synStatus status = synNodeCreate(graphHandle_,
66+
syn_concat_inputs.data(),
67+
syn_concat_outputs.data(),
68+
syn_concat_inputs.size(),
69+
syn_concat_outputs.size(),
70+
&concatParams,
71+
sizeof(concatParams),
72+
guid_concat.c_str(),
73+
name_concat.c_str(),
74+
nullptr,
75+
nullptr);
76+
PD_CHECK(status == synSuccess,
77+
"[RUNTIME] IndexPutKernel synNodeCreate (concat) "
78+
"failed = ",
79+
status);
80+
81+
// add indices concat tensor
82+
syn_scatter_inputs.push_back(syn_concat_outputs[0]);
83+
// add update tensor
84+
syn_scatter_inputs.push_back(createTensor(inputs.back().dims.size(),
85+
inputs.back().type,
86+
inputs.back().dims,
87+
true,
88+
inputs.back().name));
89+
90+
std::vector<synTensor> syn_scatter_outputs;
91+
syn_scatter_outputs.push_back(createTensor(outputs[0].dims.size(),
92+
outputs[0].type,
93+
outputs[0].dims,
94+
true,
95+
outputs[0].name,
96+
section_shared));
97+
98+
status = synNodeCreate(graphHandle_,
99+
syn_scatter_inputs.data(),
100+
syn_scatter_outputs.data(),
101+
syn_scatter_inputs.size(),
102+
syn_scatter_outputs.size(),
103+
nullptr,
104+
0,
105+
guid_.c_str(),
106+
"index_put",
107+
nullptr,
108+
nullptr);
109+
PD_CHECK(
110+
status == synSuccess, "[RUNTIME] synNodeCreate () failed = %d", status);
111+
}
112+
};
113+
114+
template <typename T, typename Context>
115+
void IndexPutKernel(const Context& dev_ctx,
116+
const phi::DenseTensor& x,
117+
const std::vector<const phi::DenseTensor*>& indices,
118+
const phi::DenseTensor& value,
119+
bool accumulate,
120+
phi::DenseTensor* out) {
121+
PD_CHECK(accumulate == false,
122+
"IndexPutKernel doesn't support accumulate=true");
123+
124+
dev_ctx.template Alloc<T>(out);
125+
126+
ConvertTensors ct;
127+
ct.Add(x);
128+
for (const auto& index : indices) {
129+
ct.Add(index);
130+
}
131+
ct.Add(value);
132+
ct.Add(out, false);
133+
134+
bool is_inplace = (out->data() == x.data());
135+
136+
OpCacheOperator op_info;
137+
std::vector<DIMS> inputs_dims = ct.GetDims();
138+
op_info.prepareOpInfo<T, nullptr_t>(
139+
is_inplace ? "IndexPutKernel" : "_IndexPutKernel", inputs_dims, nullptr);
140+
auto recipe = op_info.GetRecipe();
141+
if (recipe == nullptr) {
142+
IndexPut op;
143+
144+
op.AddNode(ct, is_inplace);
145+
op.Compile();
146+
op_info.setOp(op);
147+
148+
recipe = op_info.GetRecipe();
149+
}
150+
151+
std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
152+
RecipeRunner runner(recipe);
153+
runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
154+
}
155+
156+
} // namespace custom_kernel
157+
158+
// Register the index_put kernel with the intel_hpu plugin backend for every
// dtype the HPU scatter op supports.
PD_REGISTER_PLUGIN_KERNEL(index_put,
                          intel_hpu,
                          ALL_LAYOUT,
                          custom_kernel::IndexPutKernel,
                          phi::dtype::float16,
                          phi::dtype::bfloat16,
                          float,
                          int32_t,
                          int64_t,
                          bool) {}

0 commit comments

Comments
 (0)