Commit 1ca54ab

[INTEL_HPU] Control bucket size with environment variables (#1870)
Parent: fad434f

1 file changed: +18 -7 lines

backends/intel_hpu/custom_ops/llama_infer/prepare_block_metadata.cc

Lines changed: 18 additions & 7 deletions
@@ -165,26 +165,37 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
   auto seq_lens_decoder_cpu =
       seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
 
-  const int batch_step = 4;
-  const int block_step = 16;
-  const int max_batches = input_ids.shape()[0];
+  const char* env_prefill_batch_step = std::getenv("BATCH_STEP_PREFILL");
+  const int batch_step_prefill =
+      env_prefill_batch_step ? std::atoi(env_prefill_batch_step) : 1;
+  const char* env_decode_batch_step = std::getenv("BATCH_STEP_DECODE");
+  const int batch_step_decode =
+      env_decode_batch_step ? std::atoi(env_decode_batch_step) : 4;
+  const char* env_block_step = std::getenv("BLOCK_STEP_DECODE");
+  const int block_step = env_block_step ? std::atoi(env_block_step) : 16;
+  const char* env_max_batches = std::getenv("MAX_BATCHES_PREFILL");
+  const int max_batches_prefill =
+      env_max_batches ? std::atoi(env_max_batches) : 3;
+
+  const int max_batches_in = input_ids.shape()[0];
   const int max_seq_len = input_ids.shape()[1];
   const int max_blocks_each = block_tables.shape()[1];
   phi::DataType device_dtype = phi::StringToDataType(dtype);
 
   auto [max_enc_len, valid_batches_enc] = get_max_and_where_nonzero(
-      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches);
+      const_cast<int*>(seq_lens_encoder_cpu.data<int>()), max_batches_in);
   int enc_count = valid_batches_enc.size();
 
   auto valid_batches_dec = where_nonzero(
-      const_cast<int*>(seq_lens_decoder_cpu.data<int>()), max_batches);
+      const_cast<int*>(seq_lens_decoder_cpu.data<int>()), max_batches_in);
   int dec_count = valid_batches_dec.size();
 
   auto dummy_tensor =
       paddle::full({1}, 0, phi::DataType::FLOAT32, paddle::CPUPlace());
 
   if (enc_count > 0) {
-    int total_batch = find_bucket(enc_count, batch_step, max_batches);
+    int total_batch =
+        find_bucket(enc_count, batch_step_prefill, max_batches_prefill);
     auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true);
 
     int max_buckets = (max_enc_len + block_size - 1) / block_size;
@@ -248,7 +259,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
                 is_prompt_cpu_tensor};
 
   } else if (dec_count > 0) {
-    int total_batch = find_bucket(dec_count, batch_step, max_batches);
+    int total_batch = find_bucket(dec_count, batch_step_decode, max_batches_in);
 
     auto input_ids_column_0 =
         paddle::experimental::slice(input_ids, {1}, {0}, {1}, {}, {});
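
For reference, every new knob in this commit follows the same getenv-with-default pattern. Below is a minimal standalone sketch of that pattern; the helper name env_int_or is hypothetical (the commit reads each variable inline), but the environment variable names and default values match the diff above.

#include <cstdlib>

// Hypothetical helper illustrating the getenv-with-default pattern used above.
static int env_int_or(const char* name, int fallback) {
  const char* value = std::getenv(name);
  return value ? std::atoi(value) : fallback;
}

// Equivalent to the defaults introduced by this commit:
//   env_int_or("BATCH_STEP_PREFILL", 1)   -> batch_step_prefill
//   env_int_or("BATCH_STEP_DECODE", 4)    -> batch_step_decode
//   env_int_or("BLOCK_STEP_DECODE", 16)   -> block_step
//   env_int_or("MAX_BATCHES_PREFILL", 3)  -> max_batches_prefill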

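The rounding itself happens in find_bucket, whose implementation sits outside this hunk. As a rough illustration only (an assumption about the usual step-and-cap bucketing scheme, not the file's actual code), such a helper typically rounds the live batch count up to the next multiple of the step and clamps it at the cap, so padded batch sizes change in coarse increments rather than per request:

#include <algorithm>

// Hypothetical sketch of a step-and-cap bucketing helper; find_bucket's real
// body is not shown in this diff.
static int find_bucket_sketch(int count, int step, int max_value) {
  int bucket = ((count + step - 1) / step) * step;  // round up to a multiple of step
  return std::min(bucket, max_value);               // never exceed the configured cap
}

Under that reading, the defaults above pad prefill batches in steps of 1 up to 3, and decode batches in steps of 4 up to the full input batch size; exporting, say, BATCH_STEP_DECODE=8 changes the decode granularity without rebuilding the op.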