Commit a4c2f9e

[INTEL_HPU] Refine update_inputs_v2 OP. (#1867)
Signed-off-by: Fei Wang <[email protected]>
1 parent ba7e2b7 commit a4c2f9e

File tree: 1 file changed (+37 −13 lines)

backends/intel_hpu/custom_ops/llama_infer/update_inputs_v2.cc

Lines changed: 37 additions & 13 deletions
@@ -21,6 +21,20 @@
 #include "paddle/extension.h"
 #include "utils/utils.h"
 
+namespace custom_kernel {
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const phi::DenseTensor& x,
+                          const phi::DenseTensor& value,
+                          const phi::IntArray& starts,
+                          const phi::IntArray& ends,
+                          const phi::IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          phi::DenseTensor* out);
+}  // namespace custom_kernel
+
 bool is_in_end_v3(const int64_t id, const int64_t* end_ids, int length) {
   for (int i = 0; i < length; i++) {
     if (id == end_ids[i]) {
@@ -38,14 +52,12 @@ void cpu_wrapper(bool* not_need_stop,
                 int* seq_lens_decoder,
                 int64_t* next_tokens,
                 int64_t* kwargs_next_tokens,
-                int64_t* input_ids,
                 const int64_t* end_ids,
                 const int64_t* stop_nums,
                 const bool* is_block_step,
                 const int64_t* max_dec_len,
                 int bsz,
                 int max_bsz,
-                int input_ids_stride,
                 int end_length) {
 #pragma omp parallel for num_threads(OMP_THREAD_NUM)
   for (int i = 0; i < max_bsz; i++) {
@@ -91,8 +103,6 @@ void cpu_wrapper(bool* not_need_stop,
 
     seq_lens_this_time[i] = stop_flags[i] ? 0 : 1;
     seq_lens_encoder[i] = 0;
-    int64_t* input_ids_now = input_ids + i * input_ids_stride;
-    input_ids_now[0] = next_tokens[i];
   }
   int64_t stop_sum = 0;
   for (size_t i = 0; i < stop_flag_now_int.size(); i++) {
@@ -110,14 +120,12 @@ void update_inputs_v2(bool* not_need_stop,
                      int* seq_lens_decoder,
                      int64_t* next_tokens,
                      int64_t* kwargs_next_tokens,
-                     int64_t* input_ids,
                      const int64_t* end_ids,
                      const int64_t* stop_nums,
                      const bool* is_block_step,
                      const int64_t* max_dec_len,
                      int now_bsz,
                      int max_bsz,
-                     int input_ids_stride,
                      int end_length) {
   PD_CHECK(max_bsz <= 1024,
            "Max supported batch size is 1024. Now received ",
@@ -132,14 +140,12 @@ void update_inputs_v2(bool* not_need_stop,
              seq_lens_decoder,
              next_tokens,
              kwargs_next_tokens,
-             input_ids,
              end_ids,
              stop_nums,
              is_block_step,
              max_dec_len,
              now_bsz,
              max_bsz,
-             input_ids_stride,
              end_length);
 }
 
@@ -166,7 +172,6 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
   auto seq_lens_decoder_cpu =
       seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
   auto max_dec_len_cpu = max_dec_len.copy_to(paddle::CPUPlace(), true);
-  auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true);
   auto stop_nums_cpu = stop_nums.copy_to(paddle::CPUPlace(), true);
   auto next_tokens_cpu = next_tokens.copy_to(paddle::CPUPlace(), true);
   auto is_block_step_cpu = is_block_step.copy_to(paddle::CPUPlace(), true);
@@ -180,7 +185,6 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
 
   const int max_bsz = stop_flags.shape()[0];
   const int now_bsz = seq_lens_this_time.shape()[0];
-  const int input_ids_stride = input_ids.shape()[1];
   const int end_length = end_ids.shape()[0];
   update_inputs_v2(const_cast<bool*>(not_need_stop_cpu.data<bool>()),
                    const_cast<int64_t*>(step_idx_cpu.data<int64_t>()),
@@ -190,14 +194,12 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
                    const_cast<int*>(seq_lens_decoder_cpu.data<int>()),
                    const_cast<int64_t*>(next_tokens_cpu.data<int64_t>()),
                    const_cast<int64_t*>(kwargs_next_tokens_cpu.data<int64_t>()),
-                   const_cast<int64_t*>(input_ids_cpu.data<int64_t>()),
                    end_ids_cpu.data<int64_t>(),
                    stop_nums_cpu.data<int64_t>(),
                    is_block_step_cpu.data<bool>(),
                    max_dec_len_cpu.data<int64_t>(),
                    now_bsz,
                    max_bsz,
-                   input_ids_stride,
                    end_length);
 
   custom_kernel::copy_tensor_wrapper(dev_ctx, not_need_stop_cpu, not_need_stop);
@@ -209,10 +211,32 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
       dev_ctx, seq_lens_encoder_cpu, seq_lens_encoder);
   custom_kernel::copy_tensor_wrapper(
       dev_ctx, seq_lens_decoder_cpu, seq_lens_decoder);
-  custom_kernel::copy_tensor_wrapper(dev_ctx, input_ids_cpu, input_ids);
   custom_kernel::copy_tensor_wrapper(dev_ctx, next_tokens_cpu, next_tokens);
   custom_kernel::copy_tensor_wrapper(
       dev_ctx, kwargs_next_tokens_cpu, kwargs_next_tokens);
+
+  auto input_ids_tensor =
+      static_cast<phi::DenseTensor*>(input_ids.impl().get());
+  auto next_tokens_tensor =
+      static_cast<const phi::DenseTensor*>(next_tokens.impl().get());
+  auto starts = phi::IntArray(std::vector<int64_t>{0});
+  auto ends = phi::IntArray(std::vector<int64_t>{1});
+  auto steps = phi::IntArray(std::vector<int64_t>{1});
+  std::vector<int64_t> axes = {1};
+  std::vector<int64_t> decrease_axes;
+  std::vector<int64_t> none_axes;
+
+  custom_kernel::SetTensorValueKernel<int64_t, phi::CustomContext>(
+      *dev_ctx,
+      *input_ids_tensor,
+      *next_tokens_tensor,
+      starts,
+      ends,
+      steps,
+      axes,
+      decrease_axes,
+      none_axes,
+      input_ids_tensor);
 }
 
 std::vector<std::vector<int64_t>> UpdateInputsV2InferShape(

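Summary of the change: the CPU wrapper no longer receives input_ids (and its stride) and no longer writes next_tokens[i] into the first column of each row on the host, so input_ids is no longer round-tripped through CPU memory. Instead, after the other CPU-updated tensors are copied back, the new custom_kernel::SetTensorValueKernel call performs the same first-column write directly on the device tensor, using the slice starts={0}, ends={1}, steps={1} along axes={1} (decrease_axes and none_axes empty), which is roughly equivalent to the assignment input_ids[:, 0:1] = next_tokens.

Below is a minimal standalone sketch of that intended semantics, for illustration only. It is plain C++ on host buffers with hypothetical names (set_first_column, bsz, stride); it is not the phi/HPU kernel itself, just the slice write the kernel is assumed to perform on device.

#include <cstdint>
#include <iostream>
#include <vector>

// Write next_tokens[i] into column 0 of each row of a row-major
// [bsz, stride] buffer -- i.e. input_ids[:, 0:1] = next_tokens.
void set_first_column(std::vector<int64_t>& input_ids,
                      const std::vector<int64_t>& next_tokens,
                      int bsz,
                      int stride) {
  for (int i = 0; i < bsz; ++i) {
    input_ids[i * stride] = next_tokens[i];
  }
}

int main() {
  const int bsz = 2;
  const int stride = 4;  // hypothetical max sequence length
  std::vector<int64_t> input_ids(bsz * stride, 0);
  const std::vector<int64_t> next_tokens = {7, 9};

  set_first_column(input_ids, next_tokens, bsz, stride);

  // Prints:
  // 7 0 0 0
  // 9 0 0 0
  for (int i = 0; i < bsz; ++i) {
    for (int j = 0; j < stride; ++j) {
      std::cout << input_ids[i * stride + j] << (j + 1 < stride ? ' ' : '\n');
    }
  }
  return 0;
}

The practical effect of moving this write onto the device is that input_ids (typically the largest of the tensors involved) stays resident on the HPU; only the small control tensors are still staged through the CPU wrapper.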