@@ -64,6 +64,11 @@ class UpdateInputsV3Op : public HpuFusedOperator {
64
64
synTensor is_block_step = createTensorFromCT (ct, IS_BLOCK_STEP, true );
65
65
synTensor end_ids = createTensorFromCT (ct, END_IDS, true );
66
66
67
+ synTensor input_ids_i32 = createTensorNoPresist (
68
+ " input_ids_i32" , syn_type_int32, inputs_dims[INPUT_IDS]);
69
+ synTensor input_ids_i32_out = createTensorNoPresist (
70
+ " input_ids_i32_out" , syn_type_int32, inputs_dims[INPUT_IDS]);
71
+
67
72
synTensor not_need_stop_out =
68
73
createTensorFromCT (ct, NOT_NEED_STOP_OUT, false );
69
74
@@ -73,6 +78,13 @@ class UpdateInputsV3Op : public HpuFusedOperator {
73
78
synTensor input_ids_out =
74
79
createTensorFromCT (ct, INPUT_IDS_OUT, false , section_input_ids);
75
80
81
+ std::vector<synTensor> cast_input_ids_in = {input_ids};
82
+ std::vector<synTensor> cast_input_ids_out = {input_ids_i32};
83
+ AddNodeCast (cast_input_ids_in,
84
+ cast_input_ids_out,
85
+ " cast_i64_to_i32" ,
86
+ guid_ + " cast_input_ids_out" );
87
+
76
88
synSectionHandle section_next_tokens = createSection ();
77
89
synTensor next_tokens =
78
90
createTensorFromCT (ct, NEXT_TOKENS, true , section_next_tokens);
@@ -124,12 +136,15 @@ class UpdateInputsV3Op : public HpuFusedOperator {
124
136
seq_lens_encoder,
125
137
seq_lens_decoder,
126
138
max_dec_len,
139
+ input_ids_i32,
127
140
next_tokens_i32,
128
141
is_block_step,
129
142
end_ids,
130
143
kwargs_next_tokens_i32};
131
- std::vector<synTensor> outs = {
132
- stop_flag_now_int_out, next_tokens_i32_out, kwargs_next_tokens_i32_out};
144
+ std::vector<synTensor> outs = {stop_flag_now_int_out,
145
+ next_tokens_i32_out,
146
+ input_ids_i32_out,
147
+ kwargs_next_tokens_i32_out};
133
148
std::string node = " custom_update_inputs_v3" ;
134
149
AddNode_IO (ins, outs, node, guid_ + node);
135
150
@@ -141,6 +156,14 @@ class UpdateInputsV3Op : public HpuFusedOperator {
141
156
" cast_i32_to_i64" ,
142
157
guid_ + " cast_next_tokens_back" );
143
158
159
+ // convert input_ids output back
160
+ std::vector<synTensor> cast_input_ids_back_in = {input_ids_i32_out};
161
+ std::vector<synTensor> cast_input_ids_back_out = {input_ids_out};
162
+ AddNodeCast (cast_input_ids_back_in,
163
+ cast_input_ids_back_out,
164
+ " cast_i32_to_i64" ,
165
+ guid_ + " cast_input_ids_back" );
166
+
144
167
// convert kwargs_next_tokens output back
145
168
std::vector<synTensor> cast_kwargs_next_tokens_back_in = {
146
169
kwargs_next_tokens_i32_out};
@@ -151,26 +174,6 @@ class UpdateInputsV3Op : public HpuFusedOperator {
151
174
" cast_i32_to_i64" ,
152
175
guid_ + " cast_kwargs_next_tokens_back" );
153
176
154
- synSliceParams params = {{0 }};
155
- for (size_t i = 0 ; i < inputs_dims[INPUT_IDS].size (); i++) {
156
- params.axes [i] = i;
157
- params.steps [i] = 1 ;
158
- params.starts [i] = 0 ;
159
- params.ends [i] =
160
- inputs_dims[INPUT_IDS][inputs_dims[INPUT_IDS].size () - 1 - i];
161
- }
162
- params.starts [inputs_dims[INPUT_IDS].size () - 1 - 1 ] = 0 ;
163
- params.ends [inputs_dims[INPUT_IDS].size () - 1 - 1 ] = 1 ;
164
-
165
- std::vector<synTensor> set_value_in = {input_ids, next_tokens_out};
166
- std::vector<synTensor> set_value_out = {input_ids_out};
167
-
168
- AddNode_IOP<synSliceParams>(set_value_in,
169
- set_value_out,
170
- params,
171
- " slice_insert" ,
172
- guid_ + " slice_insert" );
173
-
174
177
ns_Reduction::Params reduce_params = {0 };
175
178
std::vector<synTensor> reduce_in = {stop_flag_now_int_out};
176
179
std::vector<synTensor> reduce_out = {stop_sum};
@@ -222,15 +225,13 @@ void update_inputs_v3(const paddle::Tensor& stop_flags,
222
225
INSERT_TENSOR_TO_CT (is_block_step, ct, IS_BLOCK_STEP, true );
223
226
INSERT_TENSOR_TO_CT (end_ids, ct, END_IDS, true );
224
227
INSERT_TENSOR_TO_CT (kwargs_next_tokens, ct, KWARGS_NEXT_TOKENS, true );
228
+ INSERT_TENSOR_TO_CT (next_tokens_i32, ct, NEXT_TOKENS_I32, true );
229
+ INSERT_TENSOR_TO_CT (kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_I32, true );
225
230
226
231
INSERT_TENSOR_TO_CT (not_need_stop, ct, NOT_NEED_STOP_OUT, false );
227
232
INSERT_TENSOR_TO_CT (input_ids, ct, INPUT_IDS_OUT, false );
228
233
INSERT_TENSOR_TO_CT (next_tokens, ct, NEXT_TOKENS_OUT, false );
229
234
INSERT_TENSOR_TO_CT (kwargs_next_tokens, ct, KWARGS_NEXT_TOKENS_OUT, false );
230
-
231
- INSERT_TENSOR_TO_CT (next_tokens_i32, ct, NEXT_TOKENS_I32, true );
232
- INSERT_TENSOR_TO_CT (kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_I32, true );
233
-
234
235
INSERT_TENSOR_TO_CT (next_tokens_i32, ct, NEXT_TOKENS_OUT_I32, false );
235
236
INSERT_TENSOR_TO_CT (
236
237
kwargs_next_tokens_i32, ct, KWARGS_NEXT_TOKENS_OUT_I32, false );
@@ -347,8 +348,12 @@ PD_BUILD_OP(update_inputs_v3)
347
348
" is_block_step" ,
348
349
" end_ids" ,
349
350
" kwargs_next_tokens" })
350
- .Outputs({" not_need_stop_out" , " next_tokens_out" , " kwargs_next_tokens_out" })
351
+ .Outputs({" not_need_stop_out" ,
352
+ " input_ids_out" ,
353
+ " next_tokens_out" ,
354
+ " kwargs_next_tokens_out" })
351
355
.SetInplaceMap({{" not_need_stop" , " not_need_stop_out" },
356
+ {" input_ids" , " input_ids_out" },
352
357
{" next_tokens" , " next_tokens_out" },
353
358
{" kwargs_next_tokens" , " kwargs_next_tokens_out" }})
354
359
.SetKernelFn(PD_KERNEL(UpdateInputsV3))
0 commit comments