|
| 1 | +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | +// not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "habanalabs/perf_lib_layer_params.h" |
| 16 | +#include "kernels/funcs.h" |
| 17 | +#include "kernels/hpu_funcs.h" |
| 18 | +#include "kernels/hpu_operator.h" |
| 19 | +#include "paddle/extension.h" |
| 20 | +#include "utils/utils.h" |
| 21 | + |
| 22 | +namespace custom_kernel { |
| 23 | + |
| 24 | +class FusedIndexSelectKernel : public HpuFusedOperator { |
| 25 | + public: |
| 26 | + FusedIndexSelectKernel() : HpuFusedOperator("fused_index_select_fwd") {} |
| 27 | + |
| 28 | + void AddNode(ConvertTensors& ct, int batch_size) { |
| 29 | + auto inputs = ct.GetTensors(); |
| 30 | + auto outputs = ct.GetTensors(false); |
| 31 | + auto num = outputs.size(); |
| 32 | + |
| 33 | + auto index = createTensorFromCT(&ct, num); |
| 34 | + auto index_length = inputs[num].dims[0]; |
| 35 | + |
| 36 | + for (decltype(num) i = 0; i < num; i++) { |
| 37 | + auto x = createTensorFromCT(&ct, i); |
| 38 | + synTensor out = nullptr; |
| 39 | + bool need_padding = (batch_size != index_length); |
| 40 | + std::string i_str = std::to_string(i); |
| 41 | + |
| 42 | + // output padding is needed by index_select when index length < bsz |
| 43 | + if (need_padding) { |
| 44 | + std::string out_name = "tmp_out" + i_str; |
| 45 | + std::vector<int64_t> out_dims = {index_length, outputs[i].dims[1]}; |
| 46 | + out = createTensorNoPresist(out_name, outputs[i].type, out_dims); |
| 47 | + } else { |
| 48 | + out = createTensorFromCT(&ct, i, false); |
| 49 | + } |
| 50 | + |
| 51 | + ns_GatherKernel::Params params; |
| 52 | + params.axis = static_cast<int32_t>(inputs[i].dims.size()) - 1; |
| 53 | + |
| 54 | + std::vector<synTensor> ins = {x, index}; |
| 55 | + std::vector<synTensor> outs = {out}; |
| 56 | + std::string node_name = "group_index_select_" + i_str; |
| 57 | + |
| 58 | + switch (inputs[i].type) { |
| 59 | + case syn_type_fixed: |
| 60 | + AddNodeIndexSelect<bool>(ins, outs, params, node_name); |
| 61 | + break; |
| 62 | + case syn_type_single: |
| 63 | + AddNodeIndexSelect<float>(ins, outs, params, node_name); |
| 64 | + break; |
| 65 | + case syn_type_int32: |
| 66 | + AddNodeIndexSelect<int32_t>(ins, outs, params, node_name); |
| 67 | + break; |
| 68 | + case syn_type_int64: |
| 69 | + AddNodeIndexSelect<int64_t>(ins, outs, params, node_name); |
| 70 | + break; |
| 71 | + default: |
| 72 | + PD_CHECK(false, |
| 73 | + "[RUNTIME] unexpected x type encountered in " |
| 74 | + "FusedIndexSelect for AddNodeIndexSelect"); |
| 75 | + break; |
| 76 | + } |
| 77 | + |
| 78 | + if (need_padding) { |
| 79 | + std::string pad_name = "pad_" + i_str; |
| 80 | + std::vector<int64_t> pad_dims = {batch_size - inputs[num].dims[0], |
| 81 | + inputs[i].dims[1]}; |
| 82 | + auto pad = createTensorNoPresist(pad_name, inputs[i].type, pad_dims); |
| 83 | + ns_ConstantKernel::Params const_params; |
| 84 | + const_params.constant.i = 0; |
| 85 | + std::vector<synTensor> full_out = {pad}; |
| 86 | + std::string full_node_name = "full_zeros_" + i_str; |
| 87 | + |
| 88 | + switch (inputs[i].type) { |
| 89 | + case syn_type_fixed: |
| 90 | + AddNodeFull<bool>(full_out, const_params, full_node_name); |
| 91 | + break; |
| 92 | + case syn_type_single: |
| 93 | + AddNodeFull<float>(full_out, const_params, full_node_name); |
| 94 | + break; |
| 95 | + case syn_type_int32: |
| 96 | + AddNodeFull<int32_t>(full_out, const_params, full_node_name); |
| 97 | + break; |
| 98 | + case syn_type_int64: |
| 99 | + AddNodeFull<int64_t>(full_out, const_params, full_node_name); |
| 100 | + break; |
| 101 | + default: |
| 102 | + PD_CHECK(false, |
| 103 | + "[RUNTIME] unexpected x type encountered in " |
| 104 | + "FusedIndexSelect for AddNodeFull"); |
| 105 | + break; |
| 106 | + } |
| 107 | + |
| 108 | + std::vector<synTensor> concat_ins = {out, pad}; |
| 109 | + auto padding_out = createTensorFromCT(&ct, i, false); |
| 110 | + std::vector<synTensor> concat_outs = {padding_out}; |
| 111 | + synConcatenateParams concatParams; |
| 112 | + concatParams.axis = 1; |
| 113 | + std::string concat_node_name = "concat_" + i_str; |
| 114 | + AddNodeConcat(concat_ins, concat_outs, concatParams, concat_node_name); |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + PD_CHECK(inputs.size() == (outputs.size() + 1), |
| 119 | + "[RUNTIME] in and out tensor numbers don't match"); |
| 120 | + } |
| 121 | +}; |
| 122 | + |
| 123 | +} // namespace custom_kernel |
| 124 | + |
// Helper for FusedIndexSelectForward. For a paddle::Tensor parameter `name`:
//   1. registers its underlying phi::DenseTensor as a graph input in `ct`;
//   2. allocates a new {batch_size, name.dims()[1]} DenseTensor of the same
//      dtype and registers it as a graph output (`ct.Add(..., false)`);
//   3. appends the output, wrapped as paddle::Tensor, to `results`.
// Relies on `ct`, `dev_ctx`, `batch_size`, and `results` being in the
// caller's scope. (Comments cannot go inside the macro body: each `\`
// continuation must be the last character on its line.)
#define DEF_DENSE_TENSOR_IN_AND_OUT(name)                                  \
  auto name##_t = static_cast<const phi::DenseTensor*>(name.impl().get()); \
  ct.Add(name##_t);                                                        \
  std::shared_ptr<phi::DenseTensor> name##_o_t =                           \
      std::make_shared<phi::DenseTensor>();                                \
  name##_o_t->Resize(phi::make_ddim({batch_size, name.dims()[1]}));        \
  dev_ctx->Alloc(name##_o_t.get(), name.dtype());                          \
  ct.Add(name##_o_t.get(), false);                                         \
  results.push_back(paddle::Tensor(name##_o_t));
| 134 | + |
// Kernel entry point for the fused_index_select custom op.
//
// Gathers rows from the 12 per-batch tensors (temperature .. min_dec_lens)
// using `sampled_ids` as the index, padding each result back up to
// `batch_size` rows on device. Returns the 12 newly-allocated output
// tensors, in the same order the inputs are registered below.
//
// Registration order matters: ConvertTensors records inputs/outputs in call
// order, and FusedIndexSelectKernel::AddNode assumes the index tensor
// (sampled_ids) is the LAST input with no paired output.
std::vector<paddle::Tensor> FusedIndexSelectForward(
    const paddle::Tensor& temperature,
    const paddle::Tensor& top_p,
    const paddle::Tensor& step_index,
    const paddle::Tensor& prompt_token_idx,
    const paddle::Tensor& pre_token_ids,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& frequency_penalties,
    const paddle::Tensor& presence_penalties,
    const paddle::Tensor& repetition_penalties,
    const paddle::Tensor& min_dec_lens,
    const paddle::Tensor& sampled_ids,
    const int batch_size) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(
          temperature.place()));

  custom_kernel::ConvertTensors ct;
  std::vector<paddle::Tensor> results;

  // Each macro call registers one input + allocates/registers its output
  // and appends that output to `results`.
  DEF_DENSE_TENSOR_IN_AND_OUT(temperature);
  DEF_DENSE_TENSOR_IN_AND_OUT(top_p);
  DEF_DENSE_TENSOR_IN_AND_OUT(step_index);
  DEF_DENSE_TENSOR_IN_AND_OUT(prompt_token_idx);
  DEF_DENSE_TENSOR_IN_AND_OUT(pre_token_ids);
  DEF_DENSE_TENSOR_IN_AND_OUT(stop_flags);
  DEF_DENSE_TENSOR_IN_AND_OUT(seq_lens_encoder);
  DEF_DENSE_TENSOR_IN_AND_OUT(seq_lens_decoder);
  DEF_DENSE_TENSOR_IN_AND_OUT(frequency_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(presence_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(repetition_penalties);
  DEF_DENSE_TENSOR_IN_AND_OUT(min_dec_lens);

  // The index tensor is input-only: registered last, with no paired output.
  auto sampled_ids_t =
      static_cast<const phi::DenseTensor*>(sampled_ids.impl().get());
  ct.Add(sampled_ids_t);

  // Cache key = kernel name + all input and output dims concatenated, so a
  // recipe is rebuilt whenever any shape (e.g. index length) changes.
  std::vector<DIMS> inputs_dims = ct.GetDims();
  std::vector<DIMS> outputs_dims = ct.GetDims(false);
  inputs_dims.insert(
      inputs_dims.end(), outputs_dims.begin(), outputs_dims.end());
  OpCacheOperator op_info;
  op_info.prepareOpInfo<float, nullptr_t>(
      "FusedIndexSelectKernel", inputs_dims, nullptr);
  auto recipe = op_info.GetRecipe();

  // Cache miss: build and compile the fused graph once, then store it.
  if (recipe == nullptr) {
    custom_kernel::FusedIndexSelectKernel op;
    op.AddNode(ct, batch_size);
    op.Compile();
    op_info.setOp(op);
    recipe = op_info.GetRecipe();
  }

  // Launch the compiled recipe on the device stream with the tensors'
  // device addresses bound by name.
  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx->stream()), tensors);

  return results;
}
| 197 | + |
// Static shape inference for fused_index_select.
//
// Fix: the op registers 12 outputs (one "_ext" tensor per input except
// sampled_ids), but the original returned only a single shape; return one
// shape per declared output, in registration order. Each output mirrors its
// input's shape here; at runtime the leading dim is actually the
// `batch_size` attribute (not visible in this signature), so these serve as
// compile-time placeholders.
std::vector<std::vector<int64_t>> FusedIndexSelectInferShape(
    const std::vector<int64_t>& temperature_shape,
    const std::vector<int64_t>& top_p_shape,
    const std::vector<int64_t>& step_index_shape,
    const std::vector<int64_t>& prompt_token_idx_shape,
    const std::vector<int64_t>& pre_token_ids_shape,
    const std::vector<int64_t>& stop_flags_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& frequency_penalties_shape,
    const std::vector<int64_t>& presence_penalties_shape,
    const std::vector<int64_t>& repetition_penalties_shape,
    const std::vector<int64_t>& min_dec_lens_shape,
    const std::vector<int64_t>& sampled_ids_shape) {
  return {temperature_shape,
          top_p_shape,
          step_index_shape,
          prompt_token_idx_shape,
          pre_token_ids_shape,
          stop_flags_shape,
          seq_lens_encoder_shape,
          seq_lens_decoder_shape,
          frequency_penalties_shape,
          presence_penalties_shape,
          repetition_penalties_shape,
          min_dec_lens_shape};
}
| 214 | + |
| 215 | +std::vector<paddle::DataType> FusedIndexSelectInferDtype( |
| 216 | + const paddle::DataType& temperature_dtype, |
| 217 | + const paddle::DataType& top_p_dtype, |
| 218 | + const paddle::DataType& step_index_dtype, |
| 219 | + const paddle::DataType& prompt_token_idx_dtype, |
| 220 | + const paddle::DataType& pre_token_ids_dtype, |
| 221 | + const paddle::DataType& stop_flags_dtype, |
| 222 | + const paddle::DataType& seq_lens_encoder_dtype, |
| 223 | + const paddle::DataType& seq_lens_decoder_dtype, |
| 224 | + const paddle::DataType& frequency_penalties_dtype, |
| 225 | + const paddle::DataType& presence_penalties_dtype, |
| 226 | + const paddle::DataType& repetition_penalties_dtype, |
| 227 | + const paddle::DataType& min_dec_lens_dtype, |
| 228 | + const paddle::DataType& sampled_ids_dtype) { |
| 229 | + return {temperature_dtype}; |
| 230 | +} |
| 231 | + |
// Registers the fused_index_select custom op with Paddle: 13 inputs (the 12
// per-batch tensors plus the sampled_ids index), 12 "_ext" outputs (one per
// non-index input, in matching order), and the batch_size attribute that
// fixes the padded output row count.
PD_BUILD_OP(fused_index_select)
    .Inputs({"temperature",
             "top_p",
             "step_index",
             "prompt_token_idx",
             "pre_token_ids",
             "stop_flags",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "frequency_penalties",
             "presence_penalties",
             "repetition_penalties",
             "min_dec_lens",
             "sampled_ids"})
    .Outputs({"temperature_ext",
              "top_p_ext",
              "step_index_ext",
              "prompt_token_idx_ext",
              "pre_token_ids_ext",
              "stop_flags_ext",
              "seq_lens_encoder_ext",
              "seq_lens_decoder_ext",
              "frequency_penalties_ext",
              "presence_penalties_ext",
              "repetition_penalties_ext",
              "min_dec_lens_ext"})
    .Attrs({"batch_size: int"})
    .SetKernelFn(PD_KERNEL(FusedIndexSelectForward))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedIndexSelectInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedIndexSelectInferDtype));
0 commit comments