@@ -165,26 +165,37 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
165
165
auto seq_lens_decoder_cpu =
166
166
seq_lens_decoder.copy_to (paddle::CPUPlace (), true );
167
167
168
- const int batch_step = 4 ;
169
- const int block_step = 16 ;
170
- const int max_batches = input_ids.shape ()[0 ];
168
+ const char * env_prefill_batch_step = std::getenv (" BATCH_STEP_PREFILL" );
169
+ const int batch_step_prefill =
170
+ env_prefill_batch_step ? std::atoi (env_prefill_batch_step) : 1 ;
171
+ const char * env_decode_batch_step = std::getenv (" BATCH_STEP_DECODE" );
172
+ const int batch_step_decode =
173
+ env_decode_batch_step ? std::atoi (env_decode_batch_step) : 4 ;
174
+ const char * env_block_step = std::getenv (" BLOCK_STEP_DECODE" );
175
+ const int block_step = env_block_step ? std::atoi (env_block_step) : 16 ;
176
+ const char * env_max_batches = std::getenv (" MAX_BATCHES_PREFILL" );
177
+ const int max_batches_prefill =
178
+ env_max_batches ? std::atoi (env_max_batches) : 3 ;
179
+
180
+ const int max_batches_in = input_ids.shape ()[0 ];
171
181
const int max_seq_len = input_ids.shape ()[1 ];
172
182
const int max_blocks_each = block_tables.shape ()[1 ];
173
183
phi::DataType device_dtype = phi::StringToDataType (dtype);
174
184
175
185
auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero (
176
- const_cast <int *>(seq_lens_encoder_cpu.data <int >()), max_batches );
186
+ const_cast <int *>(seq_lens_encoder_cpu.data <int >()), max_batches_in );
177
187
int enc_count = valid_batches_enc.size ();
178
188
179
189
auto valid_batches_dec = where_nonzero (
180
- const_cast <int *>(seq_lens_decoder_cpu.data <int >()), max_batches );
190
+ const_cast <int *>(seq_lens_decoder_cpu.data <int >()), max_batches_in );
181
191
int dec_count = valid_batches_dec.size ();
182
192
183
193
auto dummy_tensor =
184
194
paddle::full ({1 }, 0 , phi::DataType::FLOAT32, paddle::CPUPlace ());
185
195
186
196
if (enc_count > 0 ) {
187
- int total_batch = find_bucket (enc_count, batch_step, max_batches);
197
+ int total_batch =
198
+ find_bucket (enc_count, batch_step_prefill, max_batches_prefill);
188
199
auto input_ids_cpu = input_ids.copy_to (paddle::CPUPlace (), true );
189
200
190
201
int max_buckets = (max_enc_len + block_size - 1 ) / block_size;
@@ -248,7 +259,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
248
259
is_prompt_cpu_tensor};
249
260
250
261
} else if (dec_count > 0 ) {
251
- int total_batch = find_bucket (dec_count, batch_step, max_batches );
262
+ int total_batch = find_bucket (dec_count, batch_step_decode, max_batches_in );
252
263
253
264
auto input_ids_column_0 =
254
265
paddle::experimental::slice (input_ids, {1 }, {0 }, {1 }, {}, {});
0 commit comments