
Commit 31d76e2

fix(batching): Avoid theoretical hang in batcher loop (IBM#5)
- Avoid theoretical hang in batcher loop
- Avoid a couple of clones in the router generate method
- Keep attention mask tensors as integers
- Remove num_heads attribute

Co-authored-by: OlivierDehaene <[email protected]>
1 parent daa1d81 commit 31d76e2


8 files changed, +11 −21 lines


router/src/batcher.rs

Lines changed: 4 additions & 5 deletions
@@ -104,10 +104,9 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        let mut waiting_tokens = 0;
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
+        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
-            waiting_tokens += 1;
+            let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)
@@ -131,11 +130,11 @@ async fn batching_task(
                     if let Some((new_request_ids, new_batch)) =
                         db.next_batch(min_size, max_batch_size)
                     {
-                        // Reset waiting counter
-                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Reset waiting counter
+                        waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));

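The batcher fix replaces a one-shot `if let` with a `while let` around `db.next_batch`, so the task keeps pulling batches until the queue is empty before it goes back to waiting for a notification; presumably this is the "theoretical hang" in the title, where a request enqueued after the single check could sit in the DB until some later request woke the task again. The `waiting_tokens` counter now starts at 1 right after the first `client.generate` call instead of being incremented from 0, since that call has already produced one token for the batch. A minimal sketch of the control-flow idea, written in Python purely for illustration; `db`, `notifier`, and `process_batch` are hypothetical stand-ins, not the router's API:

def batching_task(db, notifier, max_batch_size):
    while True:
        notifier.wait()  # sleep until a new request is signalled

        # Old shape: a single `if` pulled at most one batch per wake-up, so work
        # enqueued after that check could wait until another signal arrived.
        # New shape: drain the queue before going back to sleep.
        while (batch := db.next_batch(None, max_batch_size)) is not None:
            process_batch(batch)  # run decode steps until the batch finishes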
router/src/server.rs

Lines changed: 1 addition & 5 deletions
@@ -90,11 +90,7 @@ async fn generate(
     // Validate request
     let (input_length, validated_request) = state
         .validation
-        // FIXME: can't we get rid of the cloning here??
-        .validate(GenerateRequest {
-            inputs: req.inputs.clone(),
-            parameters: req.parameters.clone(),
-        })
+        .validate(req.0)
         .await
         .map_err(|err| {
             tracing::error!("{}", err.to_string());

router/src/validation.rs

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ type ValidationRequest = (
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
-    #[error("top_p must be >= 0.0 or < 1.0")]
+    #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,

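The only change here is the wording of the `TopP` error, which now states the actual valid range: top_p is a nucleus-sampling probability mass, so it must be greater than 0.0 and at most 1.0. A hypothetical Python mirror of that bound, not the router's code:

def validate_top_p(top_p: float) -> None:
    # Valid range implied by the corrected message: 0.0 < top_p <= 1.0
    if not (0.0 < top_p <= 1.0):
        raise ValueError("top_p must be > 0.0 and <= 1.0")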
server/text_generation/models/bloom.py

Lines changed: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def __init__(self, model_name: str, quantize: bool = False):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.n_head // self.process_group.size(),
             device=device,
         )

server/text_generation/models/causal_lm.py

Lines changed: 2 additions & 3 deletions
@@ -251,7 +251,6 @@ def __init__(self, model_name: str, quantize=False):

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )

@@ -358,7 +357,7 @@ def generate_token(
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
             next_batch_past_key_values = [
                 [
-                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    t.view(batch.size, -1, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
                 ]
                 for layer in past
@@ -381,7 +380,7 @@ def generate_token(
             next_batch_attention_mask = torch.cat(
                 [
                     next_batch_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )

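Two ideas are at work in this file. First, the cached past key/value tensors are now reshaped with `t.view(batch.size, -1, *t.shape[-2:])`, i.e. by the known batch size rather than by `self.num_heads`, which is what lets the `num_heads` attribute be dropped from `Model` and its subclasses. Second, the extra attention-mask column is created with `new_ones`, which matches the dtype and device of the existing mask, instead of `torch.ones(...)` followed by `.to(self.device)`, which yields float32 and would promote the whole integer mask on concatenation. A small standalone PyTorch illustration of the dtype point, not project code:

import torch

mask = torch.ones((2, 5), dtype=torch.int64)  # tokenizer-style integer attention mask

old_style = torch.cat([mask, torch.ones((2, 1))], dim=1)   # float32 column
new_style = torch.cat([mask, mask.new_ones(2, 1)], dim=1)  # same dtype/device as `mask`

print(old_style.dtype)  # torch.float32 -- type promotion turns the whole mask into floats
print(new_style.dtype)  # torch.int64   -- the mask stays an integer tensor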
server/text_generation/models/galactica.py

Lines changed: 0 additions & 1 deletion
@@ -185,7 +185,6 @@ def __init__(self, model_name: str, quantize: bool = False):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=config.num_attention_heads // self.process_group.size(),
             device=device,
         )

server/text_generation/models/model.py

Lines changed: 1 addition & 2 deletions
@@ -10,9 +10,8 @@


 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+    def __init__(self, tokenizer: Tokenizer, device: torch.device):
         self.tokenizer = tokenizer
-        self.num_heads = num_heads
         self.device = device

     @property

server/text_generation/models/seq2seq_lm.py

Lines changed: 2 additions & 3 deletions
@@ -87,7 +87,7 @@ def from_pb(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
-        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
+        decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

         return cls(
             batch_id=pb.id,
@@ -319,7 +319,6 @@ def __init__(self, model_name: str, quantize=False):

         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
-            num_heads=self.model.config.num_attention_heads,
             device=device,
         )

@@ -499,7 +498,7 @@ def generate_token(
             next_batch_decoder_attention_mask = torch.cat(
                 [
                     next_batch_decoder_attention_mask,
-                    torch.ones((next_batch_size, 1)).to(self.device),
+                    next_batch_decoder_attention_mask.new_ones(next_batch_size, 1),
                 ],
                 dim=1,
             )

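Besides the same `new_ones` treatment for the decoder attention mask and the removal of `num_heads`, the `from_pb` change builds `decoder_input_ids` directly on the target device via `torch.tensor(..., device=device)`, rather than allocating it on the CPU and then copying it over with `.to(device)`. A short illustrative sketch, with hypothetical names rather than the server's code:

import torch

decoder_start_ids = [0, 0, 0]  # hypothetical start token per request in the batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

old_style = torch.tensor(decoder_start_ids).to(device).unsqueeze(-1)      # CPU alloc, then copy
new_style = torch.tensor(decoder_start_ids, device=device).unsqueeze(-1)  # single alloc on `device`

assert old_style.shape == new_style.shape == (3, 1)  # [batch_size, 1]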