Commit a094488

jvlunteren authored and tdoublep committed
Reduce PT2C Warmup time
This PR targets an issue created by Tom Parnell with the following description:

"Currently in the PT2C warmup logic we essentially perform warmup twice, once with as_concat=False and again with as_concat=True. It was implemented this way because we see some differences between "normal" batches and batches that were created from concatenation. The warmup logic essentially tries to cover both of these cases. Specifically, the differences between normal batches and post-concat batches are as follows:

1. Post-concat batches always have contiguous PKV tensors, whereas "normal" batches have contiguous PKV tensors almost all of the time but very occasionally (e.g., after the very first token is generated) have non-contiguous PKV tensors.
2. Post-concat batches contain the decoder_attention_mask tensor (for encoder-decoder models), whereas for normal batches it is set to None.

The issue relates to the following work: can we make some small code changes to essentially regularize these two cases? Since the PKV tensors are only rarely non-contiguous, can't we just force them to be contiguous before calling forward? There is some latency penalty to doing this, but since most of the time it is not needed, we might be ok. Can we also define the decoder_attention_mask for "normal" batches? Again, perhaps there is some small latency overhead from this which needs to be evaluated. These changes may incur a potential latency cost but will have the benefit of halving the warmup time. The work here is to (a) implement these changes and (b) verify that the latency overhead is minimal."

The update involves the required small code changes as described above.

Co-authored-by: Thomas Parnell <[email protected]>
1 parent 92f1978 commit a094488
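
A minimal sketch of the two regularizations described above. The helper names (force_contiguous_pkv, make_decoder_attention_mask) are hypothetical and only illustrate the idea; the actual changes live in the diffs below.

import torch

def force_contiguous_pkv(pkv):
    # A C-contiguous tensor has strides in non-increasing order; if a
    # past-key-value tensor violates that, materialize a contiguous copy
    # so the forward pass always sees the same memory layout.
    for layer in pkv:
        for x in layer:
            strides = list(x.stride())
            if strides != sorted(strides, reverse=True):
                x.data = x.data.clone(memory_format=torch.contiguous_format)
    return pkv

def make_decoder_attention_mask(attention_mask, batch_size,
                                max_decoder_input_length, padding_right_offset):
    # Pre-allocate the decoder mask (instead of leaving it None) so "normal"
    # batches expose the same tensors as post-concat batches; only the BOS
    # position is marked valid for a fresh batch.
    mask = attention_mask.new_zeros(
        batch_size, max_decoder_input_length + padding_right_offset
    )
    mask[:, 0] = 1
    return mask

Both operations are cheap in the common case where the tensors are already contiguous, which is why the latency cost is expected to be small.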

File tree

5 files changed (+40, -44 lines)


server/text_generation_server/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 
 import torch
 
-from text_generation_server.models.model import Model
+from text_generation_server.models.model import Model, PT2_COMPILE
 from transformers.models.auto import modeling_auto
 
 from text_generation_server.models.causal_lm import CausalLM
@@ -14,7 +14,7 @@
 
 FLASH_ATTENTION = os.getenv("FLASH_ATTENTION", "false").lower() == "true"
 
-__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION"]
+__all__ = ["Model", "CausalLM", "Seq2SeqLM", "get_model", "FLASH_ATTENTION", "PT2_COMPILE"]
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.

server/text_generation_server/models/model.py

Lines changed: 7 additions & 1 deletion
@@ -103,6 +103,13 @@ def parse_kwargs(kwargs):
             if pkv is not None:
                 if type(pkv) != type_pkv_dim0 or type(pkv[0]) != type_pkv_dim1:
                     kwargs["past_key_values"] = type_pkv_dim0(type_pkv_dim1(t) for t in pkv)
+
+                for t in pkv:
+                    for x in t:
+                        strides = list(x.stride())
+                        if strides != sorted(strides, reverse=True):
+                            x.data = x.data.clone(memory_format=torch.contiguous_format)
+
             return kwargs
 
         def override_forward_with_compile(self, *args, **kwargs):
@@ -113,7 +120,6 @@ def override_forward_with_run(self, *args, **kwargs):
             kwargs = parse_kwargs(kwargs)
             return run_forward(*args, **kwargs)
 
-        self.compiled = True
         self.model.forward = types.MethodType(override_forward_with_compile, self.model)
         self.model.run_forward = types.MethodType(override_forward_with_run, self.model)
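
The stride check added above relies on the fact that a contiguous tensor stores its strides in descending order, while a transposed view generally does not. A quick standalone illustration in plain PyTorch, not part of the PR:

import torch

x = torch.zeros(2, 3, 4)
print(x.stride())        # (12, 4, 1): descending, already contiguous
y = x.transpose(1, 2)    # a non-contiguous view
print(y.stride())        # (12, 1, 4): not descending

strides = list(y.stride())
print(strides != sorted(strides, reverse=True))  # True -> needs a clone

z = y.clone(memory_format=torch.contiguous_format)
print(z.is_contiguous())  # True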
server/text_generation_server/models/seq2seq_lm.py

Lines changed: 9 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 from transformers.modeling_outputs import BaseModelOutput
 
-from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8
+from text_generation_server.models.model import Model, CUDA_PAD_TO_MULT_OF_8, PT2_COMPILE
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.pb import generate_pb2
 from text_generation_server.prompt_cache import PrefixCache
@@ -207,7 +207,14 @@ def from_pb(
                 decoder_input_ids[:, -1] = tokenizer.bos_token_id
             else:
                 decoder_inputs_embeds = None
-                decoder_attention_mask = None
+                if PT2_COMPILE:
+                    decoder_attention_mask = attention_mask.new_zeros(
+                        batch_size, max_decoder_input_length + padding_right_offset
+                    )
+                    decoder_attention_mask[:, 0] = 1
+                else:
+                    decoder_attention_mask = None
+
             # Each decoder sequence only contains the bos_token
             # so decoder_input_ids is a torch tensor of size [batch_size, 1]
             decoder_input_ids = input_ids.new_full((batch_size, 1), tokenizer.bos_token_id)
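
For a fresh batch every decoder sequence holds only the BOS token, so the pre-allocated mask marks just position 0 and leaves room for the tokens still to be generated, matching the mask shape a post-concat batch already carries. A small standalone example with illustrative sizes (not taken from the PR):

import torch

batch_size = 2
max_decoder_input_length = 1   # only the BOS token so far
padding_right_offset = 3       # space reserved for new tokens (illustrative)

# stand-in for the encoder attention mask; only its dtype/device are borrowed
attention_mask = torch.ones(batch_size, 5, dtype=torch.long)

decoder_attention_mask = attention_mask.new_zeros(
    batch_size, max_decoder_input_length + padding_right_offset
)
decoder_attention_mask[:, 0] = 1
print(decoder_attention_mask)
# tensor([[1, 0, 0, 0],
#         [1, 0, 0, 0]])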

server/text_generation_server/server.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from typing import List, Optional
 
 from text_generation_server.cache import Cache
-from text_generation_server.models import Model, get_model, Seq2SeqLM
+from text_generation_server.models import Model, get_model, Seq2SeqLM, PT2_COMPILE
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.pb.generate_pb2 import ModelInfoResponse
@@ -305,7 +305,7 @@ async def serve_inner(
             t = threading.Thread(target=partial(log_gpu_stats, device, interval))
             t.start()
 
-        if model.compiled:
+        if PT2_COMPILE:
             # trigger pt2 compile for variety of tensor shapes
             print("Warming up PyTorch 2 compile...")
             warmup_t0 = time.time()

server/text_generation_server/utils/warmup.py

Lines changed: 20 additions & 37 deletions
@@ -39,10 +39,10 @@ def __force_contiguous(x):
         x.data = x.data.contiguous(memory_format=torch.channels_last).contiguous()
         return x
 
-    def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int, as_concat: bool = False):
+    def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
 
         if verbose:
-            print(">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d), as_concat: %d" % (batch_size, sequence_length, num_new_tokens, as_concat))
+            print(">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d)" % (batch_size, sequence_length, num_new_tokens))
 
         input_length = sequence_length - num_new_tokens
 
@@ -59,27 +59,16 @@ def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int, as_
             use_position_ids=model.use_position_ids,
         )
 
-        if as_concat and has_decoder_attention_mask:
-            batch.decoder_attention_mask = batch.attention_mask.new_zeros(
-                batch_size,
-                batch.max_decoder_input_length + batch.padding_right_offset
-            )
-            batch.decoder_attention_mask[:, 0:-batch.padding_right_offset] = 1
-
         model.generate_token(
             batch, first=True, for_concat=False,
         )
 
         for i in range(num_new_tokens-1):
-
-            if as_concat:
-                batch.past_key_values = tuple(tuple(__force_contiguous(t) for t in layer) for layer in batch.past_key_values)
-
             model.generate_token(batch)
 
-    def __safe_eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int, as_concat: bool = False):
+    def __safe_eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
         try:
-            __eval_shape(batch_size, sequence_length, num_new_tokens, as_concat)
+            __eval_shape(batch_size, sequence_length, num_new_tokens)
         except Exception as e:
             print(">> caught exception: ", e)
 
@@ -101,30 +90,26 @@ def __max_new_tokens_for_sequence_length(sequence_length: int):
     if verbose:
         print("[Phase 1] Probing boundaries.")
 
-    for as_concat in [True, False]:
-        for batch_size in [1, max_batch_size]:
-            max_sequence_length_for_batch_size = __max_sequence_length_for_batch_size(batch_size)
-            for sequence_length in [2, 3, max_sequence_length_for_batch_size]:
+    for batch_size in [1, max_batch_size]:
+        max_sequence_length_for_batch_size = __max_sequence_length_for_batch_size(batch_size)
+        for sequence_length in [2, 3, max_sequence_length_for_batch_size]:
+            __safe_eval_shape(
+                batch_size=batch_size,
+                sequence_length=sequence_length,
+                num_new_tokens=1,
+            )
+            if sequence_length > 2:
+                __safe_eval_shape(
+                    batch_size=batch_size,
+                    sequence_length=sequence_length,
+                    num_new_tokens=2,
+                )
+            if sequence_length > 3:
                 __safe_eval_shape(
                     batch_size=batch_size,
                     sequence_length=sequence_length,
-                    num_new_tokens=1,
-                    as_concat=as_concat,
+                    num_new_tokens=__max_new_tokens_for_sequence_length(sequence_length),
                 )
-                if sequence_length > 2:
-                    __safe_eval_shape(
-                        batch_size=batch_size,
-                        sequence_length=sequence_length,
-                        num_new_tokens=2,
-                        as_concat=as_concat,
-                    )
-                if sequence_length > 3:
-                    __safe_eval_shape(
-                        batch_size=batch_size,
-                        sequence_length=sequence_length,
-                        num_new_tokens=__max_new_tokens_for_sequence_length(sequence_length),
-                        as_concat=as_concat,
-                    )
 
     if verbose:
         print("[Phase 2] Probing random valid tensor shapes.")
@@ -142,12 +127,10 @@ def __max_new_tokens_for_sequence_length(sequence_length: int):
     rs = np.random.RandomState(seed=42)
     for i in range(n_samples):
         shape = valid_shapes[rs.randint(low=0, high=len(valid_shapes))]
-        as_concat = rs.choice([True, False])
         __safe_eval_shape(
             batch_size=shape[0],
             sequence_length=shape[1],
             num_new_tokens=shape[2],
-            as_concat=as_concat,
        )
         if verbose:
             print(">> n_samples: %3d, n_new_compiles: %3d" % (i+1, model.n_kernels-n_compiles))
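
The issue also asks to verify that the added work is cheap. A rough way to sanity-check the common case, where every past-key-value tensor is already contiguous and the stride test is a no-op, is to time the check on dummy tensors; the shapes below are illustrative only, not taken from the PR:

import time
import torch

pkv = tuple(
    tuple(torch.zeros(8, 16, 128, 64) for _ in range(2))  # (key, value) per layer
    for _ in range(24)                                     # 24 layers, illustrative
)

t0 = time.time()
for layer in pkv:
    for x in layer:
        strides = list(x.stride())
        if strides != sorted(strides, reverse=True):
            x.data = x.data.clone(memory_format=torch.contiguous_format)
print("stride check over all layers: %.3f ms" % ((time.time() - t0) * 1e3))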
