
Commit f17f870

tdoublep, maxdebayser, and njhill committed
Automated modelling of memory scaling behaviour
These changes introduce a memory scaling model, parameterized for the loaded model via measurement at startup. This allows batch sizes to be maximized for the available GPU memory while avoiding OOMs.

As a result, the manually configured MAX_BATCH_WEIGHT and MAX_PREFILL_WEIGHT env vars are no longer used. In their place is a BATCH_SAFETY_MARGIN percentage env var with a default of 20; it should hopefully rarely be necessary to override this.

Co-authored-by: Maximilien Philippe Marie de Bayser <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 1ffc616 commit f17f870
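
Note: the measurement and fitting happen in the Python server at startup and are not part of the files excerpted below. The following Rust sketch is only an illustration of the idea, with hypothetical numbers and helper names, assuming a simple least-squares fit of a per-token memory gradient and a budget reduced by the BATCH_SAFETY_MARGIN percentage.

// Illustrative sketch only (not from this commit): fit a per-token memory
// gradient from startup measurements and derive a token budget from it.
fn fit_gradient(token_counts: &[f64], memory_bytes: &[f64]) -> f64 {
    // Least squares through the origin: gradient = sum(x*y) / sum(x*x).
    let num: f64 = token_counts.iter().zip(memory_bytes).map(|(x, y)| x * y).sum();
    let den: f64 = token_counts.iter().map(|x| x * x).sum();
    num / den
}

fn main() {
    // Hypothetical measurements taken at startup: (batch tokens, bytes used).
    let tokens = [256.0, 512.0, 1024.0, 2048.0];
    let bytes = [0.9e8, 1.8e8, 3.7e8, 7.3e8];
    let nexttoken_gradient = fit_gradient(&tokens, &bytes);

    // Hypothetical 24 GB of free GPU memory, reduced by the default
    // BATCH_SAFETY_MARGIN of 20 percent.
    let weight_limit = 24.0e9 * (1.0 - 0.20);

    // Largest number of in-flight tokens the fitted model predicts will fit.
    let max_tokens = weight_limit / nexttoken_gradient;
    println!("gradient = {nexttoken_gradient:.0} bytes/token, budget = {max_tokens:.0} tokens");
}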

File tree: 18 files changed (+830, -219 lines)


integration_tests/text_generation_tests/test_server.py

Lines changed: 1 addition & 1 deletion

@@ -61,14 +61,14 @@ def start_server(
         # Reduce this so we can more easily test limit behaviour
         "--max-sequence-length", "200",
         "--max-new-tokens", "169",
-        "--max-batch-weight", "80000",
     ]

     if output_special_tokens:
         args.append("--output-special-tokens")

     env = os.environ.copy()
     env["RUST_BACKTRACE"] = "full"
+    env["ESTIMATE_MEMORY"] = "manual"
     env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
     if not include_cache_env_vars:
         env.pop("TRANSFORMERS_CACHE", None)

launcher/src/main.rs

Lines changed: 21 additions & 22 deletions

@@ -16,7 +16,7 @@ use std::{fs, io};
 use std::env::VarError;
 use std::ffi::OsString;
 use std::os::unix::process::CommandExt;
-use tracing::info;
+use tracing::{info, warn};

 // In most cases this gives the best performance for inferencing
 const DEFAULT_PYTORCH_CUDA_ALLOC_CONF: &'static str = "expandable_segments:True";
@@ -47,12 +47,10 @@ struct Args {
     max_new_tokens: usize,
     #[clap(default_value = "12", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = None, long, env)]
-    max_batch_weight: Option<usize>,
-    #[clap(default_value = None, long, env)]
-    max_prefill_weight: Option<usize>,
     #[clap(default_value = "0.2", long, env)]
     max_prefill_padding: f32,
+    #[clap(default_value = "20", long, env)]
+    batch_safety_margin: usize,
     #[clap(default_value = "24", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
@@ -112,6 +110,20 @@ fn main() -> ExitCode {
         &args.model_name, args.revision.as_deref()
     ).expect("Could not find tokenizer for model");

+    match env::var("MAX_BATCH_WEIGHT") {
+        Ok(max_batch_weight) if !max_batch_weight.trim().is_empty() => {
+            warn!("MAX_BATCH_WEIGHT is set to {max_batch_weight} but this parameter will be ignored.");
+        }
+        _ => {}
+    }
+
+    match env::var("MAX_PREFILL_WEIGHT") {
+        Ok(max_prefill_weight) if !max_prefill_weight.trim().is_empty() => {
+            warn!("MAX_PREFILL_WEIGHT is set to {max_prefill_weight} but this parameter will be ignored.");
+        }
+        _ => {}
+    }
+
     // Set PYTORCH_CUDA_ALLOC_CONF to default value if it's not set in the environment
     let cuda_alloc_conf = match env::var("PYTORCH_CUDA_ALLOC_CONF") {
         Err(VarError::NotPresent) if DEFAULT_PYTORCH_CUDA_ALLOC_CONF == "" => None,
@@ -164,7 +176,7 @@ fn main() -> ExitCode {
         args.max_sequence_length,
         args.max_new_tokens,
         args.max_batch_size,
-        args.max_batch_weight,
+        args.batch_safety_margin,
         args.shard_uds_path,
         args.cuda_process_memory_fraction,
         cuda_alloc_conf,
@@ -237,15 +249,6 @@ fn main() -> ExitCode {
         tokenizer_path,
     ];

-    if let Some(max_batch_weight) = args.max_batch_weight {
-        argv.push("--max-batch-weight".to_string());
-        argv.push(max_batch_weight.to_string());
-    }
-    if let Some(max_prefill_weight) = args.max_prefill_weight {
-        argv.push("--max-prefill-weight".to_string());
-        argv.push(max_prefill_weight.to_string());
-    }
-
     if let Some(path) = args.tls_key_path {
         argv.push("--tls-key-path".to_string());
         argv.push(path);
@@ -395,7 +398,7 @@ fn shard_manager(
     max_sequence_length: usize,
     max_new_tokens: usize,
     max_batch_size: usize,
-    max_batch_weight: Option<usize>,
+    batch_safety_margin: usize,
     uds_path: String,
     cuda_process_memory_fraction: f32,
     cuda_alloc_conf: Option<&str>,
@@ -428,6 +431,8 @@ fn shard_manager(
         max_new_tokens.to_string(),
         "--max-batch-size".to_string(),
         max_batch_size.to_string(),
+        "--batch-safety-margin".to_string(),
+        batch_safety_margin.to_string(),
         "--uds-path".to_string(),
         uds_path,
         "--cuda-process-memory-fraction".to_string(),
@@ -455,12 +460,6 @@ fn shard_manager(
         shard_argv.push(revision);
     }

-    // Maximum batch weight - used only for PT2 compile
-    if let Some(max_batch_weight) = max_batch_weight {
-        shard_argv.push("--max-batch-weight".to_string());
-        shard_argv.push(max_batch_weight.to_string());
-    }
-
     // Copy current process env
     let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

proto/generate.proto

Lines changed: 12 additions & 2 deletions

@@ -41,16 +41,26 @@ message ClearCacheResponse {}
 /// Empty request
 message ModelInfoRequest {}

+message MemoryScalingModel {
+    float prefill_linear_coef0 = 1;
+    float prefill_quadratic_coef0 = 2;
+    float prefill_quadratic_coef1 = 3;
+    float nexttoken_linear_coef0 = 4;
+    float nexttoken_linear_coef1 = 5;
+    uint64 weight_limit = 6;
+}
+
 message ModelInfoResponse {
     enum ModelType {
         CAUSAL_LM = 0;
         SEQ2SEQ_LM = 1;
     }
-
     ModelType model_type = 1;
     uint32 eos_token = 2;
     /// Whether batches are rectangular/padded (false for flash attention)
     bool batch_padding = 3;
+    /// Memory scaling model
+    MemoryScalingModel memory_scaling_model = 4;
 }

 message NextTokenChooserParameters {
@@ -211,4 +221,4 @@ message PrefixLookupRequest {
 /// Empty response
 message PrefixLookupResponse {
     uint32 prefix_length = 1;
-}
+}
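
For orientation, these fields describe two fitted curves: prefill memory as a function of input tokens (with both a linear and a quadratic fit), and next-token (decode) memory as a linear function of in-flight tokens, with weight_limit as the overall budget. How the router consumes these exact fields is in parts of the commit not shown here, so the following Rust sketch is only an assumed reading of the field names, mirroring the max-of-linear-and-quadratic logic that appears in batch_types.rs further down.

// Assumed interpretation of the MemoryScalingModel fields; the authoritative
// use of the coefficients lives in router code outside this excerpt.
pub struct MemoryScalingModel {
    pub prefill_linear_coef0: f32,
    pub prefill_quadratic_coef0: f32,
    pub prefill_quadratic_coef1: f32,
    pub nexttoken_linear_coef0: f32,
    pub nexttoken_linear_coef1: f32,
    pub weight_limit: u64,
}

impl MemoryScalingModel {
    /// Estimated prefill cost for `tokens` input tokens: the larger of the
    /// linear and quadratic fits.
    pub fn prefill_weight(&self, tokens: f64) -> f64 {
        let linear = self.prefill_linear_coef0 as f64 * tokens;
        let quadratic = self.prefill_quadratic_coef0 as f64 * tokens
            + self.prefill_quadratic_coef1 as f64 * tokens * tokens;
        linear.max(quadratic)
    }

    /// Estimated decode-time cost, assumed affine in the number of in-flight tokens.
    pub fn nexttoken_weight(&self, tokens: f64) -> f64 {
        self.nexttoken_linear_coef0 as f64 + self.nexttoken_linear_coef1 as f64 * tokens
    }

    /// Whether a hypothetical batch stays within the measured budget.
    pub fn fits(&self, prefill_tokens: f64, inflight_tokens: f64) -> bool {
        self.prefill_weight(prefill_tokens) <= self.weight_limit as f64
            && self.nexttoken_weight(inflight_tokens) <= self.weight_limit as f64
    }
}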

router/client/src/client.rs

Lines changed: 7 additions & 2 deletions

@@ -76,15 +76,20 @@ impl Client {

     /// Get shard model info
     #[instrument(skip(self))]
-    pub async fn model_info(&mut self) -> Result<(ModelType, u32, bool)> {
+    pub async fn model_info(&mut self) -> Result<(ModelType, u32, bool, MemoryScalingModel)> {
         let request = tonic::Request::new(ModelInfoRequest {});
         let response = self.stub
             .model_info(request)
             .instrument(info_span!("model_info"))
             .await?
             .into_inner();
         ModelType::try_from(response.model_type)
-            .map(|mt| (mt, response.eos_token, response.batch_padding))
+            .map(|mt| (
+                mt,
+                response.eos_token,
+                response.batch_padding,
+                response.memory_scaling_model.unwrap(),
+            ))
             .map_err(|_| ClientError::Generation("Unrecognized model type".to_string()))
     }

router/client/src/sharded_client.rs

Lines changed: 3 additions & 2 deletions

@@ -7,6 +7,7 @@ use tonic::transport::Uri;
 use crate::client::GenerateTokenResponse;
 use crate::pb::generate::v1::CachedBatch;
 use crate::pb::generate::v1::model_info_response::ModelType;
+use crate::pb::generate::v1::MemoryScalingModel;
 use crate::sharded_client::Request::{NextToken, Prefill};

 #[derive(Clone, Debug)]
@@ -138,8 +139,8 @@ impl ShardedClient {
     }

     /// Get shard model info
-    pub async fn model_info(&mut self) -> Result<(bool, u32, bool)> {
+    pub async fn model_info(&mut self) -> Result<(bool, u32, bool, MemoryScalingModel)> {
         self.clients[0].model_info().await
-            .map(|(mt, eos, bpad)| (mt == ModelType::Seq2seqLm, eos, bpad))
+            .map(|(mt, eos, bpad, mem_model)| (mt == ModelType::Seq2seqLm, eos, bpad, mem_model))
     }
 }

router/src/batch_types.rs

Lines changed: 43 additions & 37 deletions

@@ -1,32 +1,29 @@
 use std::cmp::max;
 use std::collections::BTreeSet;
 use nohash_hasher::IntMap;
-use num::integer::Roots;
 use crate::queue::Entry;

+
 pub(crate) trait BatchType: Send + Sync + Clone + 'static {
     type Stats: Default;

     /// Update batch statistics with an additional request
     fn update_stats(stats: &Self::Stats, input_length: usize, output_length: usize) -> Self::Stats;
     /// Calculate worst-case max batch weight given batch statistics
-    fn batch_max_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    fn batch_max_weight(&self, stats: &Self::Stats, batch_size: usize) -> usize;
     /// Calculate initial max batch weight given batch statistics (based on input lengths only)
-    fn batch_initial_weight(stats: &Self::Stats, batch_size: usize) -> usize;
+    fn batch_initial_weight(&self, stats: &Self::Stats, batch_size: usize) -> usize;
     /// Calculate prefill batch weight given prefill batch statistics
-    fn prefill_weight(prefill_stats: &Self::Stats, batch_size: usize) -> usize;
+    fn prefill_weight(&self, prefill_stats: &Self::Stats, batch_size: usize) -> usize;
     /// Percentage of batch tokens that are padding
     fn percent_padding(prefill_stats: &Self::Stats, batch_size: usize) -> f32;
     /// Indicate whether a hypothetical batch will exceed the combined weight limit
     fn exceeds_weight(
-        tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
+        &self, tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
     ) -> bool;
     /// Provide a count of tokens for a given batch, including padding tokens if applicable
     fn count_tokens(input_lengths: impl Iterator<Item=usize>, batch_size: usize) -> usize;

-    /// max_prefill_weight to use when none is specified
-    fn default_max_prefill_weight() -> usize;
-
     /// Compute batch statistics given map of entries
     fn compute_stats(entries: &IntMap<u64, Entry>) -> Self::Stats {
         entries.iter().fold(
@@ -45,7 +42,10 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {

 /// Non-padded batch used in flash attention
 #[derive(Clone)]
-pub(crate) struct FlashBatch {}
+pub(crate) struct FlashBatch {
+    pub(crate) prefill_gradient: f64,
+    pub(crate) nexttoken_gradient: f64,
+}

 impl BatchType for FlashBatch {
     /// Keep track of total number of input and output tokens in the batch
@@ -58,37 +58,38 @@ impl BatchType for FlashBatch {
         (total_in_tokens + input_length, total_out_tokens + output_length)
     }

-    fn batch_max_weight(total_tokens: &Self::Stats, _batch_size: usize) -> usize {
+    fn batch_max_weight(&self, total_tokens: &Self::Stats, _batch_size: usize) -> usize {
         let (total_in_tokens, total_out_tokens) = total_tokens;
-        total_in_tokens + total_out_tokens
+        ((*total_in_tokens + *total_out_tokens) as f64 * self.nexttoken_gradient) as usize
     }

-    fn batch_initial_weight((total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
-        *total_in_tokens
+    fn batch_initial_weight(&self, (total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
+        (*total_in_tokens as f64 * self.nexttoken_gradient) as usize
     }

-    fn prefill_weight((total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
-        *total_in_tokens
+    fn prefill_weight(&self, (total_in_tokens, _): &Self::Stats, _batch_size: usize) -> usize {
+        (*total_in_tokens as f64 * self.prefill_gradient) as usize
     }

     fn percent_padding(_: &Self::Stats, _batch_size: usize) -> f32 {
         0.0
     }

     fn exceeds_weight(
-        tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
+        &self, tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
     ) -> bool {
         let mut in_sum = 0;
         // Work backwards from longest projected entry
         for (batch_size, (out_len, in_len, _)) in tree.iter().rev().enumerate() {
+            let total_weight_limit = max_total_weight as f64;
             let this_out_len = *out_len;
             in_sum += *in_len;
             // Only need to check segments with output_len > current_output_len
             // will have been checked in a prior iteration
             if this_out_len <= current_output_len {
                 // Check if we breach max space for this segment
-                let token_count = in_sum + (batch_size + 1) * this_out_len;
-                if token_count > max_total_weight {
+                let seg_max_tokens = in_sum + (batch_size + 1) * this_out_len;
+                if seg_max_tokens as f64 * self.nexttoken_gradient > total_weight_limit {
                     return true
                 }
             }
@@ -100,14 +101,16 @@ impl BatchType for FlashBatch {
         input_lengths.sum()
     }

-    fn default_max_prefill_weight() -> usize {
-        8192
-    }
 }

 /// Regular rectangular padded
 #[derive(Clone)]
-pub(crate) struct PaddedBatch {}
+pub(crate) struct PaddedBatch {
+    pub(crate) prefill_linear_coef1: f64,
+    pub(crate) prefill_quadratic_coef1: f64,
+    pub(crate) prefill_quadratic_coef2: f64,
+    pub(crate) nexttoken_gradient: f64,
+}

 impl BatchType for PaddedBatch {
     /// Keep track of maximum input length, maximum output length, input token count
@@ -124,20 +127,26 @@ impl BatchType for PaddedBatch {
         )
     }

-    fn batch_max_weight(max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
+    fn batch_max_weight(&self, max_in_out_lengths: &Self::Stats, batch_size: usize) -> usize {
         let (max_input_length, max_output_length, _) = max_in_out_lengths;
-        let max_seq_len = max_input_length + max_output_length;
-        // Memory requirement roughly proportional to batch_size * seq_len^2
-        batch_size * max_seq_len.pow(2)
+        let seq_len_upper_bound = max_input_length + max_output_length;
+        ((seq_len_upper_bound * batch_size) as f64 * self.nexttoken_gradient) as usize
     }

-    fn batch_initial_weight((max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
-        batch_size * max_input_length.pow(2)
+    fn batch_initial_weight(&self, (max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
+        ((*max_input_length * batch_size) as f64 * self.nexttoken_gradient) as usize
     }

-    fn prefill_weight((max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
+    fn prefill_weight(&self, (max_input_length, _, _): &Self::Stats, batch_size: usize) -> usize {
         // Empirically, prefill latency is proportional to batch_size * seq_len^(3/2)
-        batch_size * max_input_length.pow(3).sqrt()
+        let input_tokens = batch_size * max_input_length;
+        let quad_input_tokens = (input_tokens * max_input_length) as f64;
+        let input_tokens = input_tokens as f64;
+        let linear = input_tokens * self.prefill_linear_coef1;
+        let quadratic = input_tokens * self.prefill_quadratic_coef1 +
+            quad_input_tokens * self.prefill_quadratic_coef2;
+
+        f64::max(linear, quadratic) as usize
     }

     fn percent_padding((max_input_length, _, total_in_tokens): &Self::Stats, batch_size: usize) -> f32 {
@@ -149,17 +158,18 @@ impl BatchType for PaddedBatch {
     }

     fn exceeds_weight(
-        tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
+        &self, tree: &BTreeSet<(usize, usize, usize)>, max_total_weight: usize, current_output_len: usize
     ) -> bool {
+        let total_weight_limit = max_total_weight as f64;
         let mut max_in_len = 0;
         // Work backwards from longest projected entry
         for (batch_size, (out_len, in_len, _)) in tree.iter().rev().enumerate() {
             let this_out_len = *out_len;
             max_in_len = max(max_in_len, *in_len);
             if this_out_len <= current_output_len {
                 // Check if we breach max space for this segment
-                let seq_len = max_in_len + this_out_len;
-                if seq_len.pow(2) * (batch_size + 1) > max_total_weight {
+                let seg_max_tokens = (max_in_len + this_out_len) * (batch_size + 1);
+                if seg_max_tokens as f64 * self.nexttoken_gradient > total_weight_limit {
                     return true
                 }
             }
@@ -170,8 +180,4 @@ impl BatchType for PaddedBatch {
     fn count_tokens(input_lengths: impl Iterator<Item=usize>, batch_size: usize) -> usize {
         input_lengths.max().unwrap_or(0) * batch_size
     }
-
-    fn default_max_prefill_weight() -> usize {
-        300000
-    }
 }
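
As a hypothetical worked example of the new weighting (not part of the commit), here is a test-style sketch of how the flash-attention batch type above turns measured gradients into weights once they have been received from the shard; the gradient values and token counts below are made up.

#[cfg(test)]
mod scaling_example {
    use super::*;

    #[test]
    fn flash_batch_weights_scale_with_measured_gradients() {
        // Hypothetical gradients standing in for values measured at startup.
        let flash = FlashBatch {
            prefill_gradient: 180_000.0,   // assumed bytes per input token during prefill
            nexttoken_gradient: 90_000.0,  // assumed bytes per in-flight token during decode
        };

        // Batch stats: 1_500 input tokens, up to 600 output tokens still to generate.
        let stats = (1_500, 600);

        // Worst-case decode footprint: (1_500 + 600) * 90_000 = 189_000_000.
        assert_eq!(flash.batch_max_weight(&stats, 4), 189_000_000);

        // Prefill footprint: 1_500 * 180_000 = 270_000_000, compared against the
        // weight_limit reported by the shard (already scaled by BATCH_SAFETY_MARGIN).
        assert_eq!(flash.prefill_weight(&(1_500, 0), 4), 270_000_000);
    }
}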

router/src/main.rs

Lines changed: 0 additions & 6 deletions

@@ -19,10 +19,6 @@ struct Args {
     max_new_tokens: usize,
     #[clap(default_value = "12", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = None, long, env)]
-    max_batch_weight: Option<usize>,
-    #[clap(default_value = None, long, env)]
-    max_prefill_weight: Option<usize>,
     #[clap(default_value = "0.2", long, env)]
     max_prefill_padding: f32,
     #[clap(default_value = "24", long, env)]
@@ -129,8 +125,6 @@ fn main() -> Result<(), std::io::Error> {
         max_sequence_length: args.max_sequence_length,
         max_new_tokens: args.max_new_tokens,
         max_batch_size: args.max_batch_size,
-        max_batch_weight: args.max_batch_weight,
-        max_prefill_weight: args.max_prefill_weight,
         max_prefill_padding: args.max_prefill_padding,
         max_waiting_tokens: args.max_waiting_tokens,
         client: sharded_client,
