Skip to content

Commit 9919ae1

Browse files
authored
Parallelize tokenization for /classify_batch and remove block allocator for non-causal LMs (#609)
1 parent a6b60e9 commit 9919ae1

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

router/src/infer.rs

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,7 @@ impl Infer {
161161
speculate: u32,
162162
preloaded_adapters: Vec<PreloadedAdapter>,
163163
prefix_caching: bool,
164+
is_causal_lm: bool,
164165
) -> Self {
165166
let adapter_event = Arc::new(AdapterEvent {
166167
batching_task: Notify::new(),
@@ -178,6 +179,7 @@ impl Infer {
178179
speculate,
179180
max_batch_total_tokens,
180181
prefix_caching,
182+
is_causal_lm,
181183
);
182184

183185
// Initialize with base model adapter (empty) mapping to index 0
@@ -729,13 +731,19 @@ impl Infer {
729731
.map(|(id, input)| (id as u64, input.clone()))
730732
.collect();
731733

732-
for (id, r_inputs) in request.inputs.iter().enumerate() {
733-
let inputs = r_inputs.to_string().clone();
734-
let (tokenized_inputs, input_length) = self
735-
.validation
736-
.validate_input(r_inputs.to_string(), None, Some(1))
737-
.await?;
734+
// Call validate_input on every input in the request and await the results
735+
let futures: Vec<_> = request
736+
.inputs
737+
.iter()
738+
.map(|input| self.validation.validate_input(input.clone(), None, Some(1)))
739+
.collect();
738740

741+
let all_tokenized_inputs = try_join_all(futures).await?;
742+
743+
for ((id, r_inputs), (tokenized_inputs, input_length)) in
744+
request.inputs.iter().enumerate().zip(all_tokenized_inputs)
745+
{
746+
let inputs = r_inputs.to_string().clone();
739747
let valid_request = ValidClassifyRequest {
740748
inputs,
741749
tokenized_inputs,

router/src/scheduler.rs

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ impl AdapterScheduler {
4444
speculate: u32,
4545
max_batch_total_tokens: u32,
4646
prefix_caching: bool,
47+
is_causal_lm: bool,
4748
) -> Self {
4849
let (sender, receiver) = flume::unbounded();
4950

@@ -60,6 +61,7 @@ impl AdapterScheduler {
6061
speculate,
6162
max_batch_total_tokens,
6263
prefix_caching,
64+
is_causal_lm,
6365
));
6466

6567
Self { sender }
@@ -124,6 +126,7 @@ async fn adapter_scheduler_task(
124126
speculate: u32,
125127
max_batch_total_tokens: u32,
126128
prefix_caching: bool,
129+
is_causal_lm: bool,
127130
) {
128131
let mut state = AdapterSchedulerState::new(
129132
client,
@@ -135,6 +138,7 @@ async fn adapter_scheduler_task(
135138
speculate,
136139
max_batch_total_tokens,
137140
prefix_caching,
141+
is_causal_lm,
138142
);
139143

140144
while let Ok(cmd) = receiver.recv_async().await {
@@ -209,14 +213,16 @@ impl AdapterSchedulerState {
209213
speculate: u32,
210214
max_batch_total_tokens: u32,
211215
prefix_caching: bool,
216+
is_causal_lm: bool,
212217
) -> Self {
213218
let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new(
214219
max_active_adapters,
215220
adapter_cycle_time_s,
216221
)));
217222
let loader = AdapterLoader::new(client.clone());
218223

219-
let block_allocator = (!requires_padding).then(|| {
224+
// Only causal LMs require the block allocator, due to paged attention
225+
let block_allocator = (!requires_padding && is_causal_lm).then(|| {
220226
BlockAllocator::new(
221227
max_batch_total_tokens,
222228
block_size,

router/src/server.rs

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1134,12 +1134,21 @@ pub async fn run(
11341134
generation_health.clone(),
11351135
shard_info.clone(),
11361136
);
1137+
1138+
// For non-causal LMs, the max batch total tokens is equal to the max batch prefill tokens
1139+
let is_causal_lm = shard_info.supports_generation;
1140+
let effective_max_batch_total_tokens = if is_causal_lm {
1141+
max_batch_total_tokens
1142+
} else {
1143+
max_batch_prefill_tokens
1144+
};
1145+
11371146
let infer = Infer::new(
11381147
client.clone(),
11391148
validation,
11401149
waiting_served_ratio,
11411150
max_batch_prefill_tokens,
1142-
max_batch_total_tokens,
1151+
effective_max_batch_total_tokens,
11431152
max_waiting_tokens,
11441153
max_concurrent_requests,
11451154
max_active_adapters,
@@ -1154,6 +1163,7 @@ pub async fn run(
11541163
shard_info.speculate,
11551164
shard_info.preloaded_adapters,
11561165
prefix_caching,
1166+
is_causal_lm,
11571167
);
11581168

11591169
// Duration buckets

0 commit comments

Comments
 (0)