21
21
#include " paddle/extension.h"
22
22
#include " utils/utils.h"
23
23
24
namespace custom_kernel {

// Forward declaration of the slice-assignment kernel (defined in another
// translation unit of this plugin). It is used below to write `value` into
// `x` over the region selected by `starts`/`ends`/`steps` along `axes`,
// writing the result through `out` (callers here pass the same tensor for
// `x` and `out`, i.e. an in-place update).
// NOTE(review): declaration only — exact semantics (handling of
// `decrease_axes` / `none_axes`) assumed to match Paddle's set_value
// kernel; confirm against the defining translation unit.
template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& value,
                          const phi::IntArray& starts,
                          const phi::IntArray& ends,
                          const phi::IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          phi::DenseTensor* out);

}  // namespace custom_kernel
37
+
24
38
bool is_in_end_v3 (const int64_t id, const int64_t * end_ids, int length) {
25
39
for (int i = 0 ; i < length; i++) {
26
40
if (id == end_ids[i]) {
@@ -38,14 +52,12 @@ void cpu_wrapper(bool* not_need_stop,
38
52
int * seq_lens_decoder,
39
53
int64_t * next_tokens,
40
54
int64_t * kwargs_next_tokens,
41
- int64_t * input_ids,
42
55
const int64_t * end_ids,
43
56
const int64_t * stop_nums,
44
57
const bool * is_block_step,
45
58
const int64_t * max_dec_len,
46
59
int bsz,
47
60
int max_bsz,
48
- int input_ids_stride,
49
61
int end_length) {
50
62
#pragma omp parallel for num_threads(OMP_THREAD_NUM)
51
63
for (int i = 0 ; i < max_bsz; i++) {
@@ -91,8 +103,6 @@ void cpu_wrapper(bool* not_need_stop,
91
103
92
104
seq_lens_this_time[i] = stop_flags[i] ? 0 : 1 ;
93
105
seq_lens_encoder[i] = 0 ;
94
- int64_t * input_ids_now = input_ids + i * input_ids_stride;
95
- input_ids_now[0 ] = next_tokens[i];
96
106
}
97
107
int64_t stop_sum = 0 ;
98
108
for (size_t i = 0 ; i < stop_flag_now_int.size (); i++) {
@@ -110,14 +120,12 @@ void update_inputs_v2(bool* not_need_stop,
110
120
int * seq_lens_decoder,
111
121
int64_t * next_tokens,
112
122
int64_t * kwargs_next_tokens,
113
- int64_t * input_ids,
114
123
const int64_t * end_ids,
115
124
const int64_t * stop_nums,
116
125
const bool * is_block_step,
117
126
const int64_t * max_dec_len,
118
127
int now_bsz,
119
128
int max_bsz,
120
- int input_ids_stride,
121
129
int end_length) {
122
130
PD_CHECK (max_bsz <= 1024 ,
123
131
" Max supported batch size is 1024. Now received " ,
@@ -132,14 +140,12 @@ void update_inputs_v2(bool* not_need_stop,
132
140
seq_lens_decoder,
133
141
next_tokens,
134
142
kwargs_next_tokens,
135
- input_ids,
136
143
end_ids,
137
144
stop_nums,
138
145
is_block_step,
139
146
max_dec_len,
140
147
now_bsz,
141
148
max_bsz,
142
- input_ids_stride,
143
149
end_length);
144
150
}
145
151
@@ -166,7 +172,6 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
166
172
auto seq_lens_decoder_cpu =
167
173
seq_lens_decoder.copy_to (paddle::CPUPlace (), true );
168
174
auto max_dec_len_cpu = max_dec_len.copy_to (paddle::CPUPlace (), true );
169
- auto input_ids_cpu = input_ids.copy_to (paddle::CPUPlace (), true );
170
175
auto stop_nums_cpu = stop_nums.copy_to (paddle::CPUPlace (), true );
171
176
auto next_tokens_cpu = next_tokens.copy_to (paddle::CPUPlace (), true );
172
177
auto is_block_step_cpu = is_block_step.copy_to (paddle::CPUPlace (), true );
@@ -180,7 +185,6 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
180
185
181
186
const int max_bsz = stop_flags.shape ()[0 ];
182
187
const int now_bsz = seq_lens_this_time.shape ()[0 ];
183
- const int input_ids_stride = input_ids.shape ()[1 ];
184
188
const int end_length = end_ids.shape ()[0 ];
185
189
update_inputs_v2 (const_cast <bool *>(not_need_stop_cpu.data <bool >()),
186
190
const_cast <int64_t *>(step_idx_cpu.data <int64_t >()),
@@ -190,14 +194,12 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
190
194
const_cast <int *>(seq_lens_decoder_cpu.data <int >()),
191
195
const_cast <int64_t *>(next_tokens_cpu.data <int64_t >()),
192
196
const_cast <int64_t *>(kwargs_next_tokens_cpu.data <int64_t >()),
193
- const_cast <int64_t *>(input_ids_cpu.data <int64_t >()),
194
197
end_ids_cpu.data <int64_t >(),
195
198
stop_nums_cpu.data <int64_t >(),
196
199
is_block_step_cpu.data <bool >(),
197
200
max_dec_len_cpu.data <int64_t >(),
198
201
now_bsz,
199
202
max_bsz,
200
- input_ids_stride,
201
203
end_length);
202
204
203
205
custom_kernel::copy_tensor_wrapper (dev_ctx, not_need_stop_cpu, not_need_stop);
@@ -209,10 +211,32 @@ void UpdateInputesV2(const paddle::Tensor& stop_flags,
209
211
dev_ctx, seq_lens_encoder_cpu, seq_lens_encoder);
210
212
custom_kernel::copy_tensor_wrapper (
211
213
dev_ctx, seq_lens_decoder_cpu, seq_lens_decoder);
212
- custom_kernel::copy_tensor_wrapper (dev_ctx, input_ids_cpu, input_ids);
213
214
custom_kernel::copy_tensor_wrapper (dev_ctx, next_tokens_cpu, next_tokens);
214
215
custom_kernel::copy_tensor_wrapper (
215
216
dev_ctx, kwargs_next_tokens_cpu, kwargs_next_tokens);
217
+
218
+ auto input_ids_tensor =
219
+ static_cast <phi::DenseTensor*>(input_ids.impl ().get ());
220
+ auto next_tokens_tensor =
221
+ static_cast <const phi::DenseTensor*>(next_tokens.impl ().get ());
222
+ auto starts = phi::IntArray (std::vector<int64_t >{0 });
223
+ auto ends = phi::IntArray (std::vector<int64_t >{1 });
224
+ auto steps = phi::IntArray (std::vector<int64_t >{1 });
225
+ std::vector<int64_t > axes = {1 };
226
+ std::vector<int64_t > decrease_axes;
227
+ std::vector<int64_t > none_axes;
228
+
229
+ custom_kernel::SetTensorValueKernel<int64_t , phi::CustomContext>(
230
+ *dev_ctx,
231
+ *input_ids_tensor,
232
+ *next_tokens_tensor,
233
+ starts,
234
+ ends,
235
+ steps,
236
+ axes,
237
+ decrease_axes,
238
+ none_axes,
239
+ input_ids_tensor);
216
240
}
217
241
218
242
std::vector<std::vector<int64_t >> UpdateInputsV2InferShape (
0 commit comments