@@ -1,9 +1,9 @@
 from typing import Optional
 from text_generation_server.pb import generate_pb2
 import numpy as np
-from dataclasses import fields
 import torch
 
+
 def pt2_compile_warmup(
     model: 'Model',
     max_batch_size: int,
@@ -13,10 +13,6 @@ def pt2_compile_warmup(
     n_samples: int = 10,
     verbose: bool = False
 ):
-
-
-    has_decoder_attention_mask = "decoder_attention_mask" in [x.name for x in fields(model.batch_type)]
-
     text = "test " * 10_000
 
     def __generate_prefill_request(batch_size: int, in_tokens: int, num_new_tokens: int):
@@ -40,9 +36,11 @@ def __force_contiguous(x):
         return x
 
     def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
-
         if verbose:
-            print(">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d)" % (batch_size, sequence_length, num_new_tokens))
+            print(
+                ">> evaluating shape (batch_size: %d, sequence_length: %d, num_new_tokens: %d)"
+                % (batch_size, sequence_length, num_new_tokens)
+            )
 
         input_length = sequence_length - num_new_tokens
 
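The hunk below also narrows `__safe_eval_shape` to catch only `torch.cuda.OutOfMemoryError` instead of a bare `Exception`, so the warmup keeps probing past shapes that do not fit in memory while real bugs still propagate. A minimal sketch of that pattern (the helper name `safe_eval` is illustrative, not from the diff):

    import torch

    def safe_eval(fn) -> bool:
        # Swallow only CUDA out-of-memory failures; anything else is a real
        # bug and should surface rather than be hidden by the warmup loop.
        try:
            fn()
            return True
        except torch.cuda.OutOfMemoryError as e:
            print(">> caught OOM error:", e)
            return False
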
@@ -59,33 +57,24 @@ def __eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
             use_position_ids=model.use_position_ids,
         )
 
-        model.generate_token(
-            batch, first=True, for_concat=False,
-        )
+        model.generate_token(batch, first=True, for_concat=False)
 
-        for i in range(num_new_tokens-1):
+        for i in range(num_new_tokens - 1):
             model.generate_token(batch)
 
     def __safe_eval_shape(batch_size: int, sequence_length: int, num_new_tokens: int):
         try:
             __eval_shape(batch_size, sequence_length, num_new_tokens)
-        except Exception as e:
-            print(">> caught exception: ", e)
+        except torch.cuda.OutOfMemoryError as e:
+            print(">> caught OOM error: ", e)
 
     def __max_sequence_length_for_batch_size(batch_size: int):
-        if max_batch_weight is not None:
-            return min(
-                max_sequence_length,
-                int(np.floor(np.sqrt(max_batch_weight / batch_size)))
-            )
-        else:
-            return max_sequence_length
+        return max_sequence_length if max_batch_weight is None else min(
+            max_sequence_length, int(np.floor(np.sqrt(max_batch_weight / batch_size)))
+        )
 
     def __max_new_tokens_for_sequence_length(sequence_length: int):
-        return min(
-            max_new_tokens,
-            sequence_length - 1
-        )
+        return min(max_new_tokens, sequence_length - 1)
 
     if verbose:
         print("[Phase 1] Probing boundaries.")
@@ -117,11 +106,11 @@ def __max_new_tokens_for_sequence_length(sequence_length: int):
     n_compiles = model.n_kernels
 
     valid_shapes = []
-    for batch_size in range(1, 1+max_batch_size):
+    for batch_size in range(1, 1 + max_batch_size):
         this_max_sequence_length = __max_sequence_length_for_batch_size(batch_size)
-        for sequence_length in range(1, 1+this_max_sequence_length):
+        for sequence_length in range(1, 1 + this_max_sequence_length):
             this_max_new_tokens = __max_new_tokens_for_sequence_length(sequence_length)
-            for new_tokens in range(1, 1+this_max_new_tokens):
+            for new_tokens in range(1, 1 + this_max_new_tokens):
                 valid_shapes.append((batch_size, sequence_length, new_tokens))
 
     rs = np.random.RandomState(seed=42)
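The seeded `RandomState` makes the warmup deterministic across runs. The rest of the function is not shown in this diff, but given `n_samples` in the signature, a plausible continuation draws a reproducible subset of `valid_shapes` rather than compiling every triple; a hedged sketch of that sampling step (the variable values are illustrative only):

    import numpy as np

    # Illustrative stand-ins for the enumerated shapes and sample count.
    valid_shapes = [(1, 4, 1), (1, 4, 2), (2, 3, 1), (2, 3, 2), (2, 4, 3)]
    n_samples = 3

    rs = np.random.RandomState(seed=42)
    idx = rs.choice(len(valid_shapes), size=min(n_samples, len(valid_shapes)), replace=False)
    for batch_size, sequence_length, new_tokens in (valid_shapes[i] for i in idx):
        print(batch_size, sequence_length, new_tokens)  # same shapes, same order, every run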