@@ -168,7 +168,7 @@ impl Backend {
168
168
}
169
169
}
170
170
for shape in shapes. iter ( ) {
171
- let batch = self . create_warmup_batch ( * shape, max_token as u32 ) ;
171
+ let batch = self . create_warmup_batch ( * shape, max_token as u32 , seq_bucket_size as u32 ) ;
172
172
match & self . model_type {
173
173
ModelType :: Classifier => self . predict ( batch) . await . map ( |_| ( ) ) ,
174
174
ModelType :: Embedding ( _) => self . embed ( batch) . await . map ( |_| ( ) ) ,
@@ -179,19 +179,25 @@ impl Backend {
179
179
}
180
180
181
181
#[ instrument( skip_all) ]
182
- pub fn create_warmup_batch ( & self , shape : ( u32 , u32 ) , max_token : u32 ) -> Batch {
182
+ pub fn create_warmup_batch ( & self , shape : ( u32 , u32 ) , max_token : u32 , seq_bucket_size : u32 ) -> Batch {
183
183
let ( batch_size, length) = shape;
184
+ let min_length = length. saturating_sub ( seq_bucket_size) . saturating_add ( 1 ) ;
185
+ let tmp_length = if min_length < length {
186
+ rand:: rng ( ) . random_range ( min_length..length)
187
+ } else {
188
+ length
189
+ } ;
184
190
let mut batched_input_ids = Vec :: new ( ) ;
185
191
let mut batched_token_type_ids = Vec :: new ( ) ;
186
192
let mut batched_position_ids = Vec :: new ( ) ;
187
193
let mut cumulative_seq_lengths = Vec :: with_capacity ( batch_size as usize + 1 ) ;
188
194
let mut pooled_indices = Vec :: with_capacity ( batch_size as usize ) ;
189
195
cumulative_seq_lengths. push ( 0 ) ;
190
- let input_ids: Vec < u32 > = ( 0 ..length )
196
+ let input_ids: Vec < u32 > = ( 0 ..tmp_length )
191
197
. map ( |_| rand:: rng ( ) . random_range ( 0 ..max_token) )
192
198
. collect ( ) ;
193
- let token_type_ids: Vec < u32 > = vec ! [ 0 ; length as usize ] ;
194
- let position_ids: Vec < u32 > = ( 0 ..length ) . collect ( ) ;
199
+ let token_type_ids: Vec < u32 > = vec ! [ 0 ; tmp_length as usize ] ;
200
+ let position_ids: Vec < u32 > = ( 0 ..tmp_length ) . collect ( ) ;
195
201
let mut current_length = 0 ;
196
202
for batch_id in 0 ..batch_size {
197
203
batched_input_ids. extend ( input_ids. iter ( ) . cloned ( ) ) ;
@@ -206,7 +212,7 @@ impl Backend {
206
212
token_type_ids : batched_token_type_ids,
207
213
position_ids : batched_position_ids,
208
214
cumulative_seq_lengths,
209
- max_length : length ,
215
+ max_length : tmp_length ,
210
216
pooled_indices,
211
217
raw_indices : vec ! [ ] ,
212
218
}
0 commit comments