Skip to content

Commit ac80090

Browse files
authored
[INTEL_HPU] implementation of fused_index_select (#1884)
1 parent 34c91f5 commit ac80090

File tree

1 file changed

+261
-0
lines changed

1 file changed

+261
-0
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "habanalabs/perf_lib_layer_params.h"
16+
#include "kernels/funcs.h"
17+
#include "kernels/hpu_funcs.h"
18+
#include "kernels/hpu_operator.h"
19+
#include "paddle/extension.h"
20+
#include "utils/utils.h"
21+
22+
namespace custom_kernel {
23+
24+
class FusedIndexSelectKernel : public HpuFusedOperator {
25+
public:
26+
FusedIndexSelectKernel() : HpuFusedOperator("fused_index_select_fwd") {}
27+
28+
void AddNode(ConvertTensors& ct, int batch_size) {
29+
auto inputs = ct.GetTensors();
30+
auto outputs = ct.GetTensors(false);
31+
auto num = outputs.size();
32+
33+
auto index = createTensorFromCT(&ct, num);
34+
auto index_length = inputs[num].dims[0];
35+
36+
for (decltype(num) i = 0; i < num; i++) {
37+
auto x = createTensorFromCT(&ct, i);
38+
synTensor out = nullptr;
39+
bool need_padding = (batch_size != index_length);
40+
std::string i_str = std::to_string(i);
41+
42+
// output padding is needed by index_select when index length < bsz
43+
if (need_padding) {
44+
std::string out_name = "tmp_out" + i_str;
45+
std::vector<int64_t> out_dims = {index_length, outputs[i].dims[1]};
46+
out = createTensorNoPresist(out_name, outputs[i].type, out_dims);
47+
} else {
48+
out = createTensorFromCT(&ct, i, false);
49+
}
50+
51+
ns_GatherKernel::Params params;
52+
params.axis = static_cast<int32_t>(inputs[i].dims.size()) - 1;
53+
54+
std::vector<synTensor> ins = {x, index};
55+
std::vector<synTensor> outs = {out};
56+
std::string node_name = "group_index_select_" + i_str;
57+
58+
switch (inputs[i].type) {
59+
case syn_type_fixed:
60+
AddNodeIndexSelect<bool>(ins, outs, params, node_name);
61+
break;
62+
case syn_type_single:
63+
AddNodeIndexSelect<float>(ins, outs, params, node_name);
64+
break;
65+
case syn_type_int32:
66+
AddNodeIndexSelect<int32_t>(ins, outs, params, node_name);
67+
break;
68+
case syn_type_int64:
69+
AddNodeIndexSelect<int64_t>(ins, outs, params, node_name);
70+
break;
71+
default:
72+
PD_CHECK(false,
73+
"[RUNTIME] unexpected x type encountered in "
74+
"FusedIndexSelect for AddNodeIndexSelect");
75+
break;
76+
}
77+
78+
if (need_padding) {
79+
std::string pad_name = "pad_" + i_str;
80+
std::vector<int64_t> pad_dims = {batch_size - inputs[num].dims[0],
81+
inputs[i].dims[1]};
82+
auto pad = createTensorNoPresist(pad_name, inputs[i].type, pad_dims);
83+
ns_ConstantKernel::Params const_params;
84+
const_params.constant.i = 0;
85+
std::vector<synTensor> full_out = {pad};
86+
std::string full_node_name = "full_zeros_" + i_str;
87+
88+
switch (inputs[i].type) {
89+
case syn_type_fixed:
90+
AddNodeFull<bool>(full_out, const_params, full_node_name);
91+
break;
92+
case syn_type_single:
93+
AddNodeFull<float>(full_out, const_params, full_node_name);
94+
break;
95+
case syn_type_int32:
96+
AddNodeFull<int32_t>(full_out, const_params, full_node_name);
97+
break;
98+
case syn_type_int64:
99+
AddNodeFull<int64_t>(full_out, const_params, full_node_name);
100+
break;
101+
default:
102+
PD_CHECK(false,
103+
"[RUNTIME] unexpected x type encountered in "
104+
"FusedIndexSelect for AddNodeFull");
105+
break;
106+
}
107+
108+
std::vector<synTensor> concat_ins = {out, pad};
109+
auto padding_out = createTensorFromCT(&ct, i, false);
110+
std::vector<synTensor> concat_outs = {padding_out};
111+
synConcatenateParams concatParams;
112+
concatParams.axis = 1;
113+
std::string concat_node_name = "concat_" + i_str;
114+
AddNodeConcat(concat_ins, concat_outs, concatParams, concat_node_name);
115+
}
116+
}
117+
118+
PD_CHECK(inputs.size() == (outputs.size() + 1),
119+
"[RUNTIME] in and out tensor numbers don't match");
120+
}
121+
};
122+
123+
} // namespace custom_kernel
124+
125+
// Helper macro used by FusedIndexSelectForward for each paired in/out tensor.
// For a paddle::Tensor argument `name`, it:
//   1. registers the underlying DenseTensor as a kernel INPUT in `ct`;
//   2. allocates a new output DenseTensor of shape {batch_size, name.dims()[1]}
//      with the same dtype, registers it as an OUTPUT in `ct`;
//   3. appends the output (wrapped as paddle::Tensor) to `results`.
// Relies on `ct`, `results`, `dev_ctx`, and `batch_size` being in scope at the
// expansion site. The expansion ORDER defines the tensor indices that
// FusedIndexSelectKernel::AddNode reads, so call sites must not be reordered.
#define DEF_DENSE_TENSOR_IN_AND_OUT(name)                                  \
  auto name##_t = static_cast<const phi::DenseTensor*>(name.impl().get()); \
  ct.Add(name##_t);                                                        \
  std::shared_ptr<phi::DenseTensor> name##_o_t =                           \
      std::make_shared<phi::DenseTensor>();                                \
  name##_o_t->Resize(phi::make_ddim({batch_size, name.dims()[1]}));        \
  dev_ctx->Alloc(name##_o_t.get(), name.dtype());                          \
  ct.Add(name##_o_t.get(), false);                                         \
  results.push_back(paddle::Tensor(name##_o_t));
134+
135+
// Host-side entry point for the fused_index_select custom op.
//
// Gathers rows from 12 per-request sampling-state tensors using the index
// carried in `sampled_ids`, padding each result back to `batch_size` rows.
// Compiles the HPU graph once per unique shape signature and caches the
// recipe; subsequent calls with the same shapes reuse it.
//
// Returns the 12 output tensors (one per DEF_DENSE_TENSOR_IN_AND_OUT call,
// in declaration order); `sampled_ids` is input-only.
std::vector<paddle::Tensor> FusedIndexSelectForward(
    const paddle::Tensor& temperature,
    const paddle::Tensor& top_p,
    const paddle::Tensor& step_index,
    const paddle::Tensor& prompt_token_idx,
    const paddle::Tensor& pre_token_ids,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& frequency_penalties,
    const paddle::Tensor& presence_penalties,
    const paddle::Tensor& repetition_penalties,
    const paddle::Tensor& min_dec_lens,
    const paddle::Tensor& sampled_ids,
    const int batch_size) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          temperature.place()));

  custom_kernel::ConvertTensors ct;
  std::vector<paddle::Tensor> results;

  // Registration ORDER matters: AddNode assumes inputs[0..num-1] are the
  // data tensors (in this order) and inputs[num] is the index tensor, which
  // is why sampled_ids is added last, below.
  DEF_DENSE_TENSOR_IN_AND_OUT(temperature);
  DEF_DENSE_TENSOR_IN_AND_OUT(top_p);
  DEF_DENSE_TENSOR_IN_AND_OUT(step_index);
  DEF_DENSE_TENSOR_IN_AND_OUT(prompt_token_idx);
  DEF_DENSE_TENSOR_IN_AND_OUT(pre_token_ids);
  DEF_DENSE_TENSOR_IN_AND_OUT(stop_flags);
  DEF_DENSE_TENSOR_IN_AND_OUT(seq_lens_encoder);
  DEF_DENSE_TENSOR_IN_AND_OUT(seq_lens_decoder);
  DEF_DENSE_TENSOR_IN_AND_OUT(frequency_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(presence_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(repetition_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(min_dec_lens);

  // The shared index tensor: input-only, no paired output.
  auto sampled_ids_t =
      static_cast<const phi::DenseTensor*>(sampled_ids.impl().get());
  ct.Add(sampled_ids_t);

  // Cache key = concatenated input + output dims; a shape change triggers
  // a fresh graph compile below.
  std::vector<DIMS> inputs_dims = ct.GetDims();
  std::vector<DIMS> outputs_dims = ct.GetDims(false);
  inputs_dims.insert(
      inputs_dims.end(), outputs_dims.begin(), outputs_dims.end());
  OpCacheOperator op_info;
  op_info.prepareOpInfo<float, nullptr_t>(
      "FusedIndexSelectKernel", inputs_dims, nullptr);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    // Cache miss: build and compile the fused graph, then cache the recipe.
    custom_kernel::FusedIndexSelectKernel op;
    op.AddNode(ct, batch_size);
    op.Compile();
    op_info.setOp(op);
    recipe = op_info.GetRecipe();
  }

  // Bind device addresses and launch on the context's stream.
  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx->stream()), tensors);

  return results;
}
197+
198+
// Static shape inference for fused_index_select.
//
// NOTE(review): the op registers 12 outputs but only the first shape is
// reported here, and the real runtime shape is {batch_size, dims[1]}
// (batch_size is an attribute not visible to this hook) -- presumably the
// framework tolerates the partial report for this backend; confirm.
std::vector<std::vector<int64_t>> FusedIndexSelectInferShape(
    const std::vector<int64_t>& temperature_shape,
    const std::vector<int64_t>& top_p_shape,
    const std::vector<int64_t>& step_index_shape,
    const std::vector<int64_t>& prompt_token_idx_shape,
    const std::vector<int64_t>& pre_token_ids_shape,
    const std::vector<int64_t>& stop_flags_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& frequency_penalties_shape,
    const std::vector<int64_t>& presence_penalties_shape,
    const std::vector<int64_t>& repetition_penalties_shape,
    const std::vector<int64_t>& min_dec_lens_shape,
    const std::vector<int64_t>& sampled_ids_shape) {
  // Only the leading input's shape is propagated.
  std::vector<std::vector<int64_t>> shapes;
  shapes.push_back(temperature_shape);
  return shapes;
}
214+
215+
std::vector<paddle::DataType> FusedIndexSelectInferDtype(
216+
const paddle::DataType& temperature_dtype,
217+
const paddle::DataType& top_p_dtype,
218+
const paddle::DataType& step_index_dtype,
219+
const paddle::DataType& prompt_token_idx_dtype,
220+
const paddle::DataType& pre_token_ids_dtype,
221+
const paddle::DataType& stop_flags_dtype,
222+
const paddle::DataType& seq_lens_encoder_dtype,
223+
const paddle::DataType& seq_lens_decoder_dtype,
224+
const paddle::DataType& frequency_penalties_dtype,
225+
const paddle::DataType& presence_penalties_dtype,
226+
const paddle::DataType& repetition_penalties_dtype,
227+
const paddle::DataType& min_dec_lens_dtype,
228+
const paddle::DataType& sampled_ids_dtype) {
229+
return {temperature_dtype};
230+
}
231+
232+
// Registers the fused_index_select custom op with Paddle.
// 13 inputs (12 data tensors + the sampled_ids index tensor) map to 12
// "<name>_ext" outputs; sampled_ids has no output counterpart. The input
// order here must match the argument order of FusedIndexSelectForward.
PD_BUILD_OP(fused_index_select)
    .Inputs({"temperature",
             "top_p",
             "step_index",
             "prompt_token_idx",
             "pre_token_ids",
             "stop_flags",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "frequency_penalties",
             "presence_penalties",
             "repetition_penalties",
             "min_dec_lens",
             "sampled_ids"})
    .Outputs({"temperature_ext",
              "top_p_ext",
              "step_index_ext",
              "prompt_token_idx_ext",
              "pre_token_ids_ext",
              "stop_flags_ext",
              "seq_lens_encoder_ext",
              "seq_lens_decoder_ext",
              "frequency_penalties_ext",
              "presence_penalties_ext",
              "repetition_penalties_ext",
              "min_dec_lens_ext"})
    // Target row count for every output (outputs are padded up to this).
    .Attrs({"batch_size: int"})
    .SetKernelFn(PD_KERNEL(FusedIndexSelectForward))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedIndexSelectInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedIndexSelectInferDtype));

0 commit comments

Comments
 (0)