Commit 9df74cd

fix: Don't swallow non-OOM PT compile warmup exceptions
Also adjust some formatting in warmup.py
1 parent 4c96be0 commit 9df74cd
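
The behavioral change is in __safe_eval_shape: warmup probing walks shapes up to the memory limit, so a CUDA out-of-memory error is expected and only logged, while any other exception now propagates instead of being printed and swallowed. A minimal standalone sketch of the before/after behavior (the helper names safe_eval_shape_old / safe_eval_shape_new are illustrative, not from the commit):

import torch

def safe_eval_shape_old(eval_shape, *shape):
    # Pre-fix behavior: every exception is printed and discarded,
    # so a genuine bug during warmup is silently ignored.
    try:
        eval_shape(*shape)
    except Exception as e:
        print(">> caught exception: ", e)

def safe_eval_shape_new(eval_shape, *shape):
    # Post-fix behavior: only CUDA OOM is treated as "shape too big to probe";
    # anything else (e.g. a RuntimeError from a shape mismatch) propagates.
    try:
        eval_shape(*shape)
    except torch.cuda.OutOfMemoryError as e:
        print(">> caught OOM error: ", e)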

File tree

1 file changed: +16 -27 lines changed

server/text_generation_server/utils/warmup.py

Lines changed: 16 additions & 27 deletions
@@ -1,9 +1,9 @@
 from typing import Optional
 from text_generation_server.pb import generate_pb2
 import numpy as np
-from dataclasses import fields
 import torch
 
+
 def pt2_compile_warmup(
     model: 'Model',
     max_batch_size: int,
@@ -13,10 +13,6 @@ def pt2_compile_warmup(
     n_samples: int = 10,
     verbose: bool = False
 ):
-
-
-    has_decoder_attention_mask = "decoder_attention_mask" in [ x.name for x in fields(model.batch_type) ]
-
     text = "test " * 10_000
 
     def __generate_prefill_request(batch_size: int, in_tokens: int, num_new_tokens: int):
@@ -40,9 +36,11 @@ def __force_contiguous(x):
         return x
 
     def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
-
         if verbose:
-            print(">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d)" % (batch_size, sequence_length, num_new_tokens))
+            print(
+                ">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d)"
+                % (batch_size, sequence_length, num_new_tokens)
+            )
 
         input_length = sequence_length - num_new_tokens
 
@@ -59,33 +57,24 @@ def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
             use_position_ids=model.use_position_ids,
         )
 
-        model.generate_token(
-            batch, first=True, for_concat=False,
-        )
+        model.generate_token(batch, first=True, for_concat=False)
 
-        for i in range(num_new_tokens-1):
+        for i in range(num_new_tokens - 1):
             model.generate_token(batch)
 
     def __safe_eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
         try:
             __eval_shape(batch_size, sequence_length, num_new_tokens)
-        except Exception as e:
-            print(">> caught exception: ", e)
+        except torch.cuda.OutOfMemoryError as e:
+            print(">> caught OOM error: ", e)
 
     def __max_sequence_length_for_batch_size(batch_size: int):
-        if max_batch_weight is not None:
-            return min(
-                max_sequence_length,
-                int(np.floor(np.sqrt(max_batch_weight/batch_size)))
-            )
-        else:
-            return max_sequence_length
+        return max_sequence_length if max_batch_weight is None else min(
+            max_sequence_length, int(np.floor(np.sqrt(max_batch_weight/batch_size)))
+        )
 
     def __max_new_tokens_for_sequence_length(sequence_length: int):
-        return min(
-            max_new_tokens,
-            sequence_length-1
-        )
+        return min(max_new_tokens, sequence_length - 1)
 
     if verbose:
         print("[Phase 1] Probing boundaries.")
@@ -117,11 +106,11 @@ def __max_new_tokens_for_sequence_length(sequence_length: int):
     n_compiles = model.n_kernels
 
     valid_shapes = []
-    for batch_size in range(1, 1+max_batch_size):
+    for batch_size in range(1, 1 + max_batch_size):
         this_max_sequence_length = __max_sequence_length_for_batch_size(batch_size)
-        for sequence_length in range(1, 1+this_max_sequence_length):
+        for sequence_length in range(1, 1 + this_max_sequence_length):
             this_max_new_tokens = __max_new_tokens_for_sequence_length(sequence_length)
-            for new_tokens in range(1, 1+this_max_new_tokens):
+            for new_tokens in range(1, 1 + this_max_new_tokens):
                 valid_shapes.append((batch_size, sequence_length, new_tokens))
 
     rs = np.random.RandomState(seed=42)
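
The triple loop in the hunk above enumerates every (batch_size, sequence_length, new_tokens) shape permitted by the two bounds; the seeded RandomState together with the n_samples parameter suggests a reproducible subset of these shapes is exercised next, though that logic falls outside this diff. A toy enumeration with assumed limits, ignoring max_batch_weight, to show which shapes qualify:

max_batch_size = 2        # assumed toy limit
max_sequence_length = 3   # assumed toy limit
max_new_tokens = 10       # assumed toy limit

valid_shapes = []
for batch_size in range(1, 1 + max_batch_size):
    for sequence_length in range(1, 1 + max_sequence_length):
        # At least one prompt token must remain, so new_tokens <= sequence_length - 1.
        for new_tokens in range(1, 1 + min(max_new_tokens, sequence_length - 1)):
            valid_shapes.append((batch_size, sequence_length, new_tokens))

print(valid_shapes)
# [(1, 2, 1), (1, 3, 1), (1, 3, 2), (2, 2, 1), (2, 3, 1), (2, 3, 2)]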
