Skip to content

Commit c8ff435

Browse files
authored
Adjust HPU warmup: use dummy inputs with shape more close to real scenario (#689)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 2bff275 commit c8ff435

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

backends/src/lib.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ impl Backend {
168168
}
169169
}
170170
for shape in shapes.iter() {
171-
let batch = self.create_warmup_batch(*shape, max_token as u32);
171+
let batch = self.create_warmup_batch(*shape, max_token as u32, seq_bucket_size as u32);
172172
match &self.model_type {
173173
ModelType::Classifier => self.predict(batch).await.map(|_| ()),
174174
ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
@@ -179,19 +179,25 @@ impl Backend {
179179
}
180180

181181
#[instrument(skip_all)]
182-
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32) -> Batch {
182+
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32, seq_bucket_size: u32) -> Batch {
183183
let (batch_size, length) = shape;
184+
let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
185+
let tmp_length = if min_length < length {
186+
rand::rng().random_range(min_length..length)
187+
} else {
188+
length
189+
};
184190
let mut batched_input_ids = Vec::new();
185191
let mut batched_token_type_ids = Vec::new();
186192
let mut batched_position_ids = Vec::new();
187193
let mut cumulative_seq_lengths = Vec::with_capacity(batch_size as usize + 1);
188194
let mut pooled_indices = Vec::with_capacity(batch_size as usize);
189195
cumulative_seq_lengths.push(0);
190-
let input_ids: Vec<u32> = (0..length)
196+
let input_ids: Vec<u32> = (0..tmp_length)
191197
.map(|_| rand::rng().random_range(0..max_token))
192198
.collect();
193-
let token_type_ids: Vec<u32> = vec![0; length as usize];
194-
let position_ids: Vec<u32> = (0..length).collect();
199+
let token_type_ids: Vec<u32> = vec![0; tmp_length as usize];
200+
let position_ids: Vec<u32> = (0..tmp_length).collect();
195201
let mut current_length = 0;
196202
for batch_id in 0..batch_size {
197203
batched_input_ids.extend(input_ids.iter().cloned());
@@ -206,7 +212,7 @@ impl Backend {
206212
token_type_ids: batched_token_type_ids,
207213
position_ids: batched_position_ids,
208214
cumulative_seq_lengths,
209-
max_length: length,
215+
max_length: tmp_length,
210216
pooled_indices,
211217
raw_indices: vec![],
212218
}

0 commit comments

Comments
 (0)