1
1
use std:: cmp:: max;
2
2
use std:: collections:: BTreeSet ;
3
3
use nohash_hasher:: IntMap ;
4
- use num:: integer:: Roots ;
5
4
use crate :: queue:: Entry ;
6
5
6
+
7
7
pub ( crate ) trait BatchType : Send + Sync + Clone + ' static {
8
8
type Stats : Default ;
9
9
10
10
/// Update batch statistics with an additional request
11
11
fn update_stats ( stats : & Self :: Stats , input_length : usize , output_length : usize ) -> Self :: Stats ;
12
12
/// Calculate worst-case max batch weight given batch statistics
13
- fn batch_max_weight ( stats : & Self :: Stats , batch_size : usize ) -> usize ;
13
+ fn batch_max_weight ( & self , stats : & Self :: Stats , batch_size : usize ) -> usize ;
14
14
/// Calculate initial max batch weight given batch statistics (based on input lengths only)
15
- fn batch_initial_weight ( stats : & Self :: Stats , batch_size : usize ) -> usize ;
15
+ fn batch_initial_weight ( & self , stats : & Self :: Stats , batch_size : usize ) -> usize ;
16
16
/// Calculate prefill batch weight given prefill batch statistics
17
- fn prefill_weight ( prefill_stats : & Self :: Stats , batch_size : usize ) -> usize ;
17
+ fn prefill_weight ( & self , prefill_stats : & Self :: Stats , batch_size : usize ) -> usize ;
18
18
/// Percentage of batch tokens that are padding
19
19
fn percent_padding ( prefill_stats : & Self :: Stats , batch_size : usize ) -> f32 ;
20
20
/// Indicate whether a hypothetical batch will exceed the combined weight limit
21
21
fn exceeds_weight (
22
- tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
22
+ & self , tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
23
23
) -> bool ;
24
24
/// Provide a count of tokens for a given batch, including padding tokens if applicable
25
25
fn count_tokens ( input_lengths : impl Iterator < Item =usize > , batch_size : usize ) -> usize ;
26
26
27
- /// max_prefill_weight to use when none is specified
28
- fn default_max_prefill_weight ( ) -> usize ;
29
-
30
27
/// Compute batch statistics given map of entries
31
28
fn compute_stats ( entries : & IntMap < u64 , Entry > ) -> Self :: Stats {
32
29
entries. iter ( ) . fold (
@@ -45,7 +42,10 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
45
42
46
43
/// Non-padded batch used in flash attention.
///
/// Batch weights are token counts scaled by one of the gradients below and
/// then truncated to `usize`.
#[derive(Clone, Debug)]
pub(crate) struct FlashBatch {
    /// Multiplier applied to the total input token count by `prefill_weight`.
    pub(crate) prefill_gradient: f64,
    /// Multiplier applied to token counts by `batch_max_weight`,
    /// `batch_initial_weight` and `exceeds_weight`.
    pub(crate) nexttoken_gradient: f64,
}
49
49
50
50
impl BatchType for FlashBatch {
51
51
/// Keep track of total number of input and output tokens in the batch
@@ -58,37 +58,38 @@ impl BatchType for FlashBatch {
58
58
( total_in_tokens + input_length, total_out_tokens + output_length)
59
59
}
60
60
61
- fn batch_max_weight ( total_tokens : & Self :: Stats , _batch_size : usize ) -> usize {
61
+ fn batch_max_weight ( & self , total_tokens : & Self :: Stats , _batch_size : usize ) -> usize {
62
62
let ( total_in_tokens, total_out_tokens) = total_tokens;
63
- total_in_tokens + total_out_tokens
63
+ ( ( * total_in_tokens + * total_out_tokens) as f64 * self . nexttoken_gradient ) as usize
64
64
}
65
65
66
- fn batch_initial_weight ( ( total_in_tokens, _) : & Self :: Stats , _batch_size : usize ) -> usize {
67
- * total_in_tokens
66
+ fn batch_initial_weight ( & self , ( total_in_tokens, _) : & Self :: Stats , _batch_size : usize ) -> usize {
67
+ ( * total_in_tokens as f64 * self . nexttoken_gradient ) as usize
68
68
}
69
69
70
- fn prefill_weight ( ( total_in_tokens, _) : & Self :: Stats , _batch_size : usize ) -> usize {
71
- * total_in_tokens
70
+ fn prefill_weight ( & self , ( total_in_tokens, _) : & Self :: Stats , _batch_size : usize ) -> usize {
71
+ ( * total_in_tokens as f64 * self . prefill_gradient ) as usize
72
72
}
73
73
74
74
fn percent_padding ( _: & Self :: Stats , _batch_size : usize ) -> f32 {
75
75
0.0
76
76
}
77
77
78
78
fn exceeds_weight (
79
- tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
79
+ & self , tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
80
80
) -> bool {
81
81
let mut in_sum = 0 ;
82
82
// Work backwards from longest projected entry
83
83
for ( batch_size, ( out_len, in_len, _) ) in tree. iter ( ) . rev ( ) . enumerate ( ) {
84
+ let total_weight_limit = max_total_weight as f64 ;
84
85
let this_out_len = * out_len;
85
86
in_sum += * in_len;
86
87
// Segments with output_len > current_output_len need no check here:
// they will have been checked in a prior iteration
88
89
if this_out_len <= current_output_len {
89
90
// Check if we breach max space for this segment
90
- let token_count = in_sum + ( batch_size + 1 ) * this_out_len;
91
- if token_count > max_total_weight {
91
+ let seg_max_tokens = in_sum + ( batch_size + 1 ) * this_out_len;
92
+ if seg_max_tokens as f64 * self . nexttoken_gradient > total_weight_limit {
92
93
return true
93
94
}
94
95
}
@@ -100,14 +101,16 @@ impl BatchType for FlashBatch {
100
101
input_lengths. sum ( )
101
102
}
102
103
103
- fn default_max_prefill_weight ( ) -> usize {
104
- 8192
105
- }
106
104
}
107
105
108
106
/// Regular rectangular padded batch.
///
/// Weights are computed from padded token counts (batch size times the
/// longest sequence length) scaled by the coefficients below.
#[derive(Clone, Debug)]
pub(crate) struct PaddedBatch {
    /// Coefficient of the padded input token count in the linear prefill
    /// estimate used by `prefill_weight`.
    pub(crate) prefill_linear_coef1: f64,
    /// Coefficient of the padded input token count in the quadratic prefill
    /// estimate used by `prefill_weight`.
    pub(crate) prefill_quadratic_coef1: f64,
    /// Coefficient of (padded input token count * max input length) in the
    /// quadratic prefill estimate used by `prefill_weight`.
    pub(crate) prefill_quadratic_coef2: f64,
    /// Multiplier applied to padded token counts by `batch_max_weight`,
    /// `batch_initial_weight` and `exceeds_weight`.
    pub(crate) nexttoken_gradient: f64,
}
111
114
112
115
impl BatchType for PaddedBatch {
113
116
/// Keep track of maximum input length, maximum output length, input token count
@@ -124,20 +127,26 @@ impl BatchType for PaddedBatch {
124
127
)
125
128
}
126
129
127
- fn batch_max_weight ( max_in_out_lengths : & Self :: Stats , batch_size : usize ) -> usize {
130
+ fn batch_max_weight ( & self , max_in_out_lengths : & Self :: Stats , batch_size : usize ) -> usize {
128
131
let ( max_input_length, max_output_length, _) = max_in_out_lengths;
129
- let max_seq_len = max_input_length + max_output_length;
130
- // Memory requirement roughly proportional to batch_size * seq_len^2
131
- batch_size * max_seq_len. pow ( 2 )
132
+ let seq_len_upper_bound = max_input_length + max_output_length;
133
+ ( ( seq_len_upper_bound * batch_size) as f64 * self . nexttoken_gradient ) as usize
132
134
}
133
135
134
- fn batch_initial_weight ( ( max_input_length, _, _) : & Self :: Stats , batch_size : usize ) -> usize {
135
- batch_size * max_input_length . pow ( 2 )
136
+ fn batch_initial_weight ( & self , ( max_input_length, _, _) : & Self :: Stats , batch_size : usize ) -> usize {
137
+ ( ( * max_input_length * batch_size) as f64 * self . nexttoken_gradient ) as usize
136
138
}
137
139
138
- fn prefill_weight ( ( max_input_length, _, _) : & Self :: Stats , batch_size : usize ) -> usize {
140
+ fn prefill_weight ( & self , ( max_input_length, _, _) : & Self :: Stats , batch_size : usize ) -> usize {
139
141
// Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
140
- batch_size * max_input_length. pow ( 3 ) . sqrt ( )
142
+ let input_tokens = batch_size * max_input_length;
143
+ let quad_input_tokens = ( input_tokens * max_input_length) as f64 ;
144
+ let input_tokens = input_tokens as f64 ;
145
+ let linear = input_tokens * self . prefill_linear_coef1 ;
146
+ let quadratic = input_tokens * self . prefill_quadratic_coef1 +
147
+ quad_input_tokens * self . prefill_quadratic_coef2 ;
148
+
149
+ f64:: max ( linear, quadratic) as usize
141
150
}
142
151
143
152
fn percent_padding ( ( max_input_length, _, total_in_tokens) : & Self :: Stats , batch_size : usize ) -> f32 {
@@ -149,17 +158,18 @@ impl BatchType for PaddedBatch {
149
158
}
150
159
151
160
fn exceeds_weight (
152
- tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
161
+ & self , tree : & BTreeSet < ( usize , usize , usize ) > , max_total_weight : usize , current_output_len : usize
153
162
) -> bool {
163
+ let total_weight_limit = max_total_weight as f64 ;
154
164
let mut max_in_len = 0 ;
155
165
// Work backwards from longest projected entry
156
166
for ( batch_size, ( out_len, in_len, _) ) in tree. iter ( ) . rev ( ) . enumerate ( ) {
157
167
let this_out_len = * out_len;
158
168
max_in_len = max ( max_in_len, * in_len) ;
159
169
if this_out_len <= current_output_len {
160
170
// Check if we breach max space for this segment
161
- let seq_len = max_in_len + this_out_len;
162
- if seq_len . pow ( 2 ) * ( batch_size + 1 ) > max_total_weight {
171
+ let seg_max_tokens = ( max_in_len + this_out_len) * ( batch_size + 1 ) ;
172
+ if seg_max_tokens as f64 * self . nexttoken_gradient > total_weight_limit {
163
173
return true
164
174
}
165
175
}
@@ -170,8 +180,4 @@ impl BatchType for PaddedBatch {
170
180
/// Count the tokens in a padded batch, including padding tokens.
///
/// Returns the longest input length multiplied by `batch_size`, or `0` when
/// `input_lengths` yields no items.
fn count_tokens(input_lengths: impl Iterator<Item = usize>, batch_size: usize) -> usize {
    let longest_input = input_lengths.max().unwrap_or(0);
    longest_input * batch_size
}
173
-
174
- fn default_max_prefill_weight ( ) -> usize {
175
- 300000
176
- }
177
183
}
0 commit comments