Skip to content

Commit 712ea0c

Browse files
authored
[INTEL_HPU] fix update_input_v3 accuracy issue (#1983)
Signed-off-by: Fei Wang <[email protected]>
1 parent ce5c75d commit 712ea0c

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

backends/intel_hpu/custom_ops/llama_infer/update_inputs_v3.cc

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ class UpdateInputsV3Op : public HpuFusedOperator {
6464
synTensor is_block_step = createTensorFromCT(ct, IS_BLOCK_STEP, true);
6565
synTensor end_ids = createTensorFromCT(ct, END_IDS, true);
6666

67+
synTensor input_ids_i32 = createTensorNoPresist(
68+
"input_ids_i32", syn_type_int32, inputs_dims[INPUT_IDS]);
69+
synTensor input_ids_i32_out = createTensorNoPresist(
70+
"input_ids_i32_out", syn_type_int32, inputs_dims[INPUT_IDS]);
71+
6772
synTensor not_need_stop_out =
6873
createTensorFromCT(ct, NOT_NEED_STOP_OUT, false);
6974

@@ -73,6 +78,13 @@ class UpdateInputsV3Op : public HpuFusedOperator {
7378
synTensor input_ids_out =
7479
createTensorFromCT(ct, INPUT_IDS_OUT, false, section_input_ids);
7580

81+
std::vector<synTensor> cast_input_ids_in = {input_ids};
82+
std::vector<synTensor> cast_input_ids_out = {input_ids_i32};
83+
AddNodeCast(cast_input_ids_in,
84+
cast_input_ids_out,
85+
"cast_i64_to_i32",
86+
guid_ + "cast_input_ids_out");
87+
7688
synSectionHandle section_next_tokens = createSection();
7789
synTensor next_tokens =
7890
createTensorFromCT(ct, NEXT_TOKENS, true, section_next_tokens);
@@ -124,12 +136,15 @@ class UpdateInputsV3Op : public HpuFusedOperator {
124136
seq_lens_encoder,
125137
seq_lens_decoder,
126138
max_dec_len,
139+
input_ids_i32,
127140
next_tokens_i32,
128141
is_block_step,
129142
end_ids,
130143
kwargs_next_tokens_i32};
131-
std::vector<synTensor> outs = {
132-
stop_flag_now_int_out, next_tokens_i32_out, kwargs_next_tokens_i32_out};
144+
std::vector<synTensor> outs = {stop_flag_now_int_out,
145+
next_tokens_i32_out,
146+
input_ids_i32_out,
147+
kwargs_next_tokens_i32_out};
133148
std::string node = "custom_update_inputs_v3";
134149
AddNode_IO(ins, outs, node, guid_ + node);
135150

@@ -141,6 +156,14 @@ class UpdateInputsV3Op : public HpuFusedOperator {
141156
"cast_i32_to_i64",
142157
guid_ + "cast_next_tokens_back");
143158

159+
// convert input_ids output back
160+
std::vector<synTensor> cast_input_ids_back_in = {input_ids_i32_out};
161+
std::vector<synTensor> cast_input_ids_back_out = {input_ids_out};
162+
AddNodeCast(cast_input_ids_back_in,
163+
cast_input_ids_back_out,
164+
"cast_i32_to_i64",
165+
guid_ + "cast_input_ids_back");
166+
144167
// convert kwargs_next_tokens output back
145168
std::vector<synTensor> cast_kwargs_next_tokens_back_in = {
146169
kwargs_next_tokens_i32_out};
@@ -151,26 +174,6 @@ class UpdateInputsV3Op : public HpuFusedOperator {
151174
"cast_i32_to_i64",
152175
guid_ + "cast_kwargs_next_tokens_back");
153176

154-
synSliceParams params = {{0}};
155-
for (size_t i = 0; i < inputs_dims[INPUT_IDS].size(); i++) {
156-
params.axes[i] = i;
157-
params.steps[i] = 1;
158-
params.starts[i] = 0;
159-
params.ends[i] =
160-
inputs_dims[INPUT_IDS][inputs_dims[INPUT_IDS].size() - 1 - i];
161-
}
162-
params.starts[inputs_dims[INPUT_IDS].size() - 1 - 1] = 0;
163-
params.ends[inputs_dims[INPUT_IDS].size() - 1 - 1] = 1;
164-
165-
std::vector<synTensor> set_value_in = {input_ids, next_tokens_out};
166-
std::vector<synTensor> set_value_out = {input_ids_out};
167-
168-
AddNode_IOP<synSliceParams>(set_value_in,
169-
set_value_out,
170-
params,
171-
"slice_insert",
172-
guid_ + "slice_insert");
173-
174177
ns_Reduction::Params reduce_params = {0};
175178
std::vector<synTensor> reduce_in = {stop_flag_now_int_out};
176179
std::vector<synTensor> reduce_out = {stop_sum};
@@ -222,15 +225,13 @@ void update_inputs_v3(const paddle::Tensor& stop_flags,
222225
INSERT_TENSOR_TO_CT(is_block_step, ct, IS_BLOCK_STEP, true);
223226
INSERT_TENSOR_TO_CT(end_ids, ct, END_IDS, true);
224227
INSERT_TENSOR_TO_CT(kwargs_next_tokens, ct, KWARGS_NEXT_TOKENS, true);
228+
INSERT_TENSOR_TO_CT(next_tokens_i32, ct, NEXT_TOKENS_I32, true);
229+
INSERT_TENSOR_TO_CT(kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_I32, true);
225230

226231
INSERT_TENSOR_TO_CT(not_need_stop, ct, NOT_NEED_STOP_OUT, false);
227232
INSERT_TENSOR_TO_CT(input_ids, ct, INPUT_IDS_OUT, false);
228233
INSERT_TENSOR_TO_CT(next_tokens, ct, NEXT_TOKENS_OUT, false);
229234
INSERT_TENSOR_TO_CT(kwargs_next_tokens, ct, KWARGS_NEXT_TOKENS_OUT, false);
230-
231-
INSERT_TENSOR_TO_CT(next_tokens_i32, ct, NEXT_TOKENS_I32, true);
232-
INSERT_TENSOR_TO_CT(kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_I32, true);
233-
234235
INSERT_TENSOR_TO_CT(next_tokens_i32, ct, NEXT_TOKENS_OUT_I32, false);
235236
INSERT_TENSOR_TO_CT(
236237
kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_OUT_I32, false);
@@ -347,8 +348,12 @@ PD_BUILD_OP(update_inputs_v3)
347348
"is_block_step",
348349
"end_ids",
349350
"kwargs_next_tokens"})
350-
.Outputs({"not_need_stop_out", "next_tokens_out", "kwargs_next_tokens_out"})
351+
.Outputs({"not_need_stop_out",
352+
"input_ids_out",
353+
"next_tokens_out",
354+
"kwargs_next_tokens_out"})
351355
.SetInplaceMap({{"not_need_stop", "not_need_stop_out"},
356+
{"input_ids", "input_ids_out"},
352357
{"next_tokens", "next_tokens_out"},
353358
{"kwargs_next_tokens", "kwargs_next_tokens_out"}})
354359
.SetKernelFn(PD_KERNEL(UpdateInputsV3))

backends/intel_hpu/kernels/set_value_kernel.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class SetTensorValueExp : public HpuOperator {
118118
template <typename T, typename Context>
119119
void SetTensorValueKernel(const Context& dev_ctx,
120120
const phi::DenseTensor& x,
121-
const phi::DenseTensor& value,
121+
const phi::DenseTensor& val,
122122
const phi::IntArray& starts,
123123
const phi::IntArray& ends,
124124
const phi::IntArray& steps,
@@ -130,8 +130,14 @@ void SetTensorValueKernel(const Context& dev_ctx,
130130
auto starts_v = starts.GetData();
131131
auto ends_v = ends.GetData();
132132

133+
paddle::Tensor val_tensor(std::make_shared<phi::DenseTensor>(val));
134+
auto value_tensor =
135+
custom_kernel::copy_tensor_wrapper(&dev_ctx, val_tensor, val.place());
136+
137+
auto value = static_cast<phi::DenseTensor*>(value_tensor.impl().get());
138+
133139
const auto& x_dims = x.dims();
134-
const auto& value_dims = value.dims();
140+
const auto& value_dims = value->dims();
135141

136142
PADDLE_ENFORCE_EQ(
137143
starts_v.size(),
@@ -181,7 +187,7 @@ void SetTensorValueKernel(const Context& dev_ctx,
181187
synRecipeHandle recipe;
182188
if (v_sum != v_new_sum) {
183189
std::vector<int64_t> input_dim = phi::vectorize<int64_t>(x_dims);
184-
std::vector<int64_t> value_dim = phi::vectorize<int64_t>(value.dims());
190+
std::vector<int64_t> value_dim = phi::vectorize<int64_t>(value->dims());
185191
std::vector<int64_t> outputs_dim = phi::vectorize<int64_t>(out->dims());
186192
std::vector<int64_t> value_new_dim = phi::vectorize<int64_t>(new_dims);
187193

@@ -226,7 +232,7 @@ void SetTensorValueKernel(const Context& dev_ctx,
226232
// runtime
227233
std::map<std::string, uint64_t> tensors;
228234
tensors["input"] = reinterpret_cast<uint64_t>(x.data<T>());
229-
tensors["value"] = reinterpret_cast<uint64_t>(value.data<T>());
235+
tensors["value"] = reinterpret_cast<uint64_t>(value->data<T>());
230236
tensors["output"] = reinterpret_cast<uint64_t>(out->data<T>());
231237

232238
RecipeRunner runner(recipe);

0 commit comments

Comments (0)