Commit 12d9106
Fix OOM due to large prompt cache (#39)
This handles the OOM problem with large prefixes by both:
- Taking the max prefix cache size into account when running the memory usage estimator, to ensure a full prefix cache does not cause an OOM
- Taking the prefix length into consideration when deciding if a request will fit into a batch, to avoid large prefixes causing unexpectedly large memory allocations

This includes an API-breaking change to the config: the prefix cache will not be enabled unless a user explicitly sets PREFIX_STORE_PATH to some non-empty value.

Signed-off-by: Joe Runde <[email protected]>
1 parent c2eb7f7 commit 12d9106
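
A note on the breaking change: with this commit the prompt cache is strictly opt-in. The sketch below mirrors the new parsing in server/text_generation_server/prompt_cache.py (shown in full further down); the helper function is only for illustration, since the real module resolves the path once at import time.

    import os
    from pathlib import Path
    from typing import Optional

    def resolve_prefix_store_path() -> Optional[Path]:
        # Illustrative helper mirroring the module-level logic in prompt_cache.py.
        raw = os.getenv("PREFIX_STORE_PATH", None)
        # Unset or empty -> None, which disables the prompt cache entirely.
        return Path(raw) if raw else None

    # Previously the path defaulted to "prompt_prefixes"; now the cache (and the
    # prompt encoder setup in model.py) is skipped unless the deployment explicitly
    # sets PREFIX_STORE_PATH, e.g. PREFIX_STORE_PATH=/models/prompt_prefixes.
    print(resolve_prefix_store_path())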

File tree: 11 files changed, +75 −33 lines


router/src/batch_types.rs

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ pub(crate) trait BatchType: Send + Sync + Clone + 'static {
         let generated_count = entry.generated_tokens;
         Self::update_stats(
             &stats,
-            entry.input_length + generated_count as usize,
+            entry.input_length + entry.prefix_length + generated_count as usize,
             (entry.request.parameters.max_new_tokens - generated_count) as usize,
         )
     }

router/src/batcher.rs

Lines changed: 8 additions & 5 deletions

@@ -39,6 +39,7 @@ use crate::pb::fmaas::StopReason::{
     Cancelled, EosToken, Error, MaxTokens, NotFinished, StopSequence, TimeLimit, TokenLimit
 };
 use crate::pb::fmaas::token_info::TopToken;
+use crate::validation::RequestSize;

 /// Batcher
 #[derive(Clone)]
@@ -96,14 +97,15 @@ impl Batcher {
     pub(crate) async fn infer(
         &self,
         input_length: usize,
+        prefix_length: usize,
         request: GenerateRequest,
     ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();

         // Try to add the request to the queue
         self.enqueue_request(vec![
-            Entry::new(request, input_length, Some(response_tx), None),
+            Entry::new(request, input_length, prefix_length, Some(response_tx), None),
         ])?;

         // Await on the response from the background task
@@ -117,14 +119,14 @@ impl Batcher {
     // Add a batch of new requests to the queue and return an vec of futures that will generate the text
     pub(crate) async fn infer_batch(
         &self,
-        requests: Vec<(usize, GenerateRequest)>,
+        requests: Vec<(RequestSize, GenerateRequest)>,
     ) -> Result<Vec<Map<Receiver<Result<InferResponse, ClientError>>,
         impl FnOnce(Result<Result<InferResponse, ClientError>, RecvError>) -> Result<InferResponse, InferError> + '_>>, InferError> {

         let mut response_chans= vec![];

         let entries: Vec<Entry> = requests.into_iter()
-            .map(|(input_length, request)| {
+            .map(|(request_size, request)| {
                 // One shot channel to communicate with the background batching task
                 let (response_tx, response_rx) = oneshot::channel();
                 response_chans.push(response_rx
@@ -134,7 +136,7 @@ impl Batcher {
                     })
                 );

-                Entry::new(request, input_length, Some(response_tx), None)
+                Entry::new(request, request_size.input_length, request_size.prefix_length, Some(response_tx), None)
             }).collect();

         // Try to add the request to the queue
@@ -147,6 +149,7 @@ impl Batcher {
     pub(crate) async fn infer_stream<T, C>(
         &self,
         input_length: usize,
+        prefix_length: usize,
         request: GenerateRequest,
         result_map: fn (Result<InferResponse, InferError>) -> T,
         on_drop: fn (&C, u32, StopReason, Option<u64>, Option<Times>, String, Option<InferError>),
@@ -170,7 +173,7 @@ impl Batcher {

         // Try to add the request to the queue
         self.enqueue_request(vec![
-            Entry::new(request, input_length, None, Some(response_tx)),
+            Entry::new(request, input_length, prefix_length, None, Some(response_tx)),
         ])?;

         Ok(ResponseStream {

router/src/grpc_server.rs

Lines changed: 9 additions & 9 deletions

@@ -25,7 +25,7 @@ use crate::server::ServerState;
 use unicode_truncate::UnicodeTruncateStr;
 use crate::pb::fmaas::model_info_response::ModelKind;
 use crate::tokenizer::AsyncTokenizer;
-use crate::validation::ValidationError;
+use crate::validation::{RequestSize, ValidationError};

 /// Whether to fail if sampling parameters are provided in greedy-mode requests
 /// or to silently ignore them.
@@ -127,18 +127,18 @@ impl GenerationService for GenerationServicer {

         if batch_size == 1 {
             // Single request case
-            let (input_length, request) = valids.into_iter().next().unwrap();
-            self.state.batcher.infer(input_length, request)
+            let (request_size, request) = valids.into_iter().next().unwrap();
+            self.state.batcher.infer(request_size.input_length, request_size.prefix_length, request)
                 .map_ok(|response| {
                     log_response(
-                        &response.times, input_length, response.gen_token_count, response.reason,
+                        &response.times, request_size.input_length, response.gen_token_count, response.reason,
                         &response.output_text, start_time, "single", "Request", response.request_id
                     );
                     vec![response.into()]
                 }).await
         } else {
             // Batch size > 1
-            let input_tokens = valids.iter().map(|r| r.0).collect::<Vec<usize>>();
+            let input_tokens = valids.iter().map(|r| r.0.input_length).collect::<Vec<usize>>();
             match self.state.batcher.infer_batch(valids).await {
                 Ok(response_chans) => {
                     try_join_all(response_chans.into_iter().zip(input_tokens).enumerate()
@@ -198,13 +198,13 @@ impl GenerationService for GenerationServicer {
         )?;

         // Validate request
-        let (input_length, validated_request) = self
+        let (request_size, validated_request) = self
            .validate(sr.prefix_id, sr.params, vec![req.text], start_time)
            .await?
            .pop().unwrap();

        let stream = self.state.batcher
-           .infer_stream(input_length, validated_request, |r| match r {
+           .infer_stream(request_size.input_length, request_size.prefix_length, validated_request, |r| match r {
               Ok(resp) => Ok(resp.into()),
               Err(err) => Err(Status::from_error(Box::new(err))),
           }, |ctx, count, reason, request_id, times, out, err| {
@@ -222,7 +222,7 @@ impl GenerationService for GenerationServicer {
             }
         }, StreamContext {
             span: Span::current(),
-            input_token_count: input_length,
+            input_token_count: request_size.input_length,
             start_time,
             _permit: permit,
         })
@@ -297,7 +297,7 @@ impl GenerationServicer {
         parameters: Option<Parameters>,
         inputs: Vec<String>,
         start_time: Instant,
-    ) -> Result<Vec<(usize, GenerateRequest)>, Status> {
+    ) -> Result<Vec<(RequestSize, GenerateRequest)>, Status> {
         match convert_params(parameters, self.state.default_include_stop_seqs) {
             Ok(params) => self.state.validation.validate(
                 prefix_id, params, inputs

router/src/queue.rs

Lines changed: 8 additions & 2 deletions

@@ -32,6 +32,8 @@ pub(crate) struct Entry {
     pub stream_tx: Option<UnboundedSender<Result<InferResponse, ClientError>>>,
     /// Number of tokens in the input
    pub input_length: usize,
+   /// Number of virtual tokens in the prefix, if one is specified
+   pub prefix_length: usize,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch (queue end time)
@@ -52,6 +54,7 @@ impl Entry {
    pub(crate) fn new(
        request: GenerateRequest,
        input_length: usize,
+       prefix_length: usize,
        response_tx: Option<Sender<Result<InferResponse, ClientError>>>,
        stream_tx: Option<UnboundedSender<Result<InferResponse, ClientError>>>,
    ) -> Self {
@@ -60,6 +63,7 @@ impl Entry {
            response_tx,
            stream_tx,
            input_length,
+           prefix_length,
            input_tokens: vec![],
            queue_time: Instant::now(),
            batch_time: None,
@@ -265,7 +269,9 @@ impl<B: BatchType> Queue<B> {
                break
            }

-           let input_len = entry.input_length;
+           // For the purposes of deciding if a request can fit into a batch,
+           // the input length needs to take the length of the prefix into account as well
+           let input_len = entry.input_length + entry.prefix_length;
            let output_len = entry.request.parameters.max_new_tokens as usize;
            let next_stats = <B>::update_stats(
                &batch_stats, input_len, output_len
@@ -289,7 +295,7 @@ impl<B: BatchType> Queue<B> {
                let generated_count = e.generated_tokens as usize;
                t.insert((
                    e.request.parameters.max_new_tokens as usize - generated_count,
-                   e.input_length + generated_count,
+                   e.input_length + e.prefix_length + generated_count,
                    t.len(),
                ));
            }
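
To make the queue change concrete, here is a rough Python re-statement of the admission rule. The actual logic is the Rust above in router/src/queue.rs; the function name and token budget below are made up purely for illustration.

    def fits_in_batch(input_length: int, prefix_length: int, max_new_tokens: int,
                      batch_token_budget: int, tokens_already_in_batch: int) -> bool:
        # The prefix's virtual tokens occupy memory just like real input tokens,
        # so they count toward the request's footprint.
        effective_len = input_length + prefix_length + max_new_tokens
        return tokens_already_in_batch + effective_len <= batch_token_budget

    # A 2000-token prefix on a 50-token prompt is sized like a ~2050-token request,
    # instead of looking tiny and later blowing past the memory estimate.
    print(fits_in_batch(50, 2000, 100, 4096, 2500))  # False: would overflow the budget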

router/src/server.rs

Lines changed: 2 additions & 2 deletions

@@ -91,7 +91,7 @@ async fn generate(
     // Validate request
     //let details = req.0.parameters.details;
     let GenerateRequest {inputs, prefix_id, parameters} = req.0;
-    let (input_length, validated_request) =
+    let (request_size, validated_request) =
         state.validation.validate(
             prefix_id, parameters, vec![inputs]
         ).await.map_err(|err| {
@@ -102,7 +102,7 @@ async fn generate(
     // Inference
     let response = state
         .batcher
-        .infer(input_length, validated_request)
+        .infer(request_size.input_length, request_size.prefix_length, validated_request)
         .await
         .map_err(|err| {
             tracing::error!("{err}");

router/src/validation.rs

Lines changed: 13 additions & 5 deletions

@@ -26,6 +26,11 @@ pub struct Validation {
     prefix_cache: Cache<String, usize, RandomState>,
 }

+pub struct RequestSize {
+    pub(crate) input_length: usize,
+    pub(crate) prefix_length: usize
+}
+
 impl Validation {
     pub(crate) fn new(
         tokenizer: AsyncTokenizer,
@@ -55,7 +60,7 @@ impl Validation {
         prefix_id: Option<String>,
         params: GenerateParameters,
         inputs: Vec<String>,
-    ) -> Result<Vec<(usize, GenerateRequest)>, ValidationError> {
+    ) -> Result<Vec<(RequestSize, GenerateRequest)>, ValidationError> {
         let min_new_tokens = params.min_new_tokens as usize;
         let max_new_tokens = params.max_new_tokens as usize;

@@ -165,18 +170,21 @@ impl Validation {
                }

                Ok((
-                   input_length,
+                   RequestSize {
+                       input_length,
+                       prefix_length
+                   },
                   GenerateRequest {
                       prefix_id: prefix_id.clone(),
                       inputs: input,
                       parameters,
                   }
               ))
           }
-       }).collect::<Result<Vec<(usize, GenerateRequest)>, ValidationError>>().map(|results| {
+       }).collect::<Result<Vec<(RequestSize, GenerateRequest)>, ValidationError>>().map(|results| {
           // Only record these for successful validation
-          for (input_length, _) in &results {
-              metrics::histogram!("tgi_request_input_length", *input_length as f64);
+          for (request_size, _) in &results {
+              metrics::histogram!("tgi_request_input_length", request_size.input_length as f64);
              metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
          }
          results

server/tests/test_prompt_cache.py

Lines changed: 11 additions & 4 deletions

@@ -23,24 +23,31 @@
 INTEGRATION_TESTS_DIR = os.path.join(REPO_ROOT, "integration_tests")


+@pytest.fixture(autouse=True)
+def temp_prompt_store(tmp_path):
+    # Unless overriden by another fixture, sets the prefix store path to some temp dir
+    with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", tmp_path):
+        yield
+
+
 @pytest.fixture()
-def temp_prompt_store():
+def integration_test_prompts():
     with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", Path(os.path.join(INTEGRATION_TESTS_DIR, "prompt_prefixes"))):
         yield


 @pytest.fixture()
-def tiny_starcoder_decoder_prompt(temp_prompt_store):
+def tiny_starcoder_decoder_prompt(integration_test_prompts):
     return "tiny_starcoder"


 @pytest.fixture()
-def tiny_raw_llama_peft_adapter_prompt(temp_prompt_store):
+def tiny_raw_llama_peft_adapter_prompt(integration_test_prompts):
     return "tinyllama_peft_adapter_raw"


 @pytest.fixture()
-def tiny_llama_peft_adapter_prompt(temp_prompt_store):
+def tiny_llama_peft_adapter_prompt(integration_test_prompts):
     return "tinyllama_peft_adapter"
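
Because temp_prompt_store is now autouse, every test in this module sees a throwaway prefix store by default and only opts into the checked-in integration prefixes by requesting integration_test_prompts. A small sketch of how the fixtures compose (the test names and assertions below are hypothetical, not part of the diff):

    from text_generation_server import prompt_cache

    def test_prefix_store_defaults_to_temp_dir():
        # The autouse fixture has already patched PREFIX_STORE_PATH to tmp_path here.
        assert prompt_cache.PREFIX_STORE_PATH is not None

    def test_with_integration_prefixes(integration_test_prompts):
        # Explicitly requesting the fixture re-patches the path to
        # integration_tests/prompt_prefixes for this test only.
        assert prompt_cache.PREFIX_STORE_PATH.name == "prompt_prefixes"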

server/text_generation_server/cli.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ def serve(
     max_sequence_length: int = 2048,
     max_new_tokens: int = 1024,
     max_batch_size: int = 12,
-    batch_safety_margin: int = 20,
+    batch_safety_margin: int = typer.Option(20, help="Integer from 0-100, a percentage of free GPU memory to hold back as a safety margin to avoid OOM"),
     revision: Optional[str] = None,
     sharded: bool = False,
     cuda_process_memory_fraction: float = 1.0,
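
The new help text pins down what batch_safety_margin means: a percentage of free GPU memory held back to avoid OOM. As a rough illustration only (the helper name below is hypothetical and the estimator itself is not among the diffs expanded on this page):

    def usable_memory_bytes(free_gpu_memory_bytes: int, batch_safety_margin: int) -> int:
        # Hold back batch_safety_margin percent of free memory so that the
        # batcher's sizing estimate errs on the side of not OOMing.
        return int(free_gpu_memory_bytes * (100 - batch_safety_margin) / 100)

    # With 40 GiB free and the default margin of 20, batches are sized against ~32 GiB.
    print(usable_memory_bytes(40 * 1024**3, 20) / 1024**3)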

server/text_generation_server/models/model.py

Lines changed: 7 additions & 1 deletion

@@ -10,6 +10,7 @@

 from transformers import PreTrainedModel

+import text_generation_server.prompt_cache
 from text_generation_server.models.types import Batch, GenerateError
 from text_generation_server.inference_engine.engine import BaseInferenceEngine
 from text_generation_server.pb import generate_pb2
@@ -44,7 +45,8 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype, max_seq_leng
         # Check whether model supports position_ids
         self.use_position_ids = "position_ids" in inspect.signature(self.model.forward).parameters

-        prompt_prefix_supported = self._setup_prompt_encoder()
+        # Short-circuit: Don't set up the prompt encoder if the prompt cache is not set
+        prompt_prefix_supported = self.prompt_cache_set() and self._setup_prompt_encoder()

         if prompt_prefix_supported:
             # Set up prefix cache
@@ -184,6 +186,10 @@ def get_indices_to_keep(
             next_batch_keep_indices.append(i)
         return next_batch_keep_indices

+    @staticmethod
+    def prompt_cache_set() -> bool:
+        return text_generation_server.prompt_cache.PREFIX_STORE_PATH is not None
+
     def _setup_prompt_encoder(self) -> bool:
         try:
             self.word_embeddings = self.model.get_input_embeddings()

server/text_generation_server/prompt_cache.py

Lines changed: 2 additions & 1 deletion

@@ -9,7 +9,8 @@

 import torch

-PREFIX_STORE_PATH = Path(os.getenv("PREFIX_STORE_PATH", "prompt_prefixes"))
+_PREFIX_STORE_PATH_STR = os.getenv("PREFIX_STORE_PATH", None)
+PREFIX_STORE_PATH = Path(_PREFIX_STORE_PATH_STR) if _PREFIX_STORE_PATH_STR else None

 VALID_PREFIX_ID_PATTERN = re.compile("[/\\w\\-]+")
 PROMPT_CACHE_SIZE_MB = int(os.getenv("PROMPT_CACHE_SIZE_MB", "512"))
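
The commit message's first point, accounting for the maximum prefix cache size in the memory usage estimator, lands in a file that is not expanded on this page, but the idea can be sketched from the PROMPT_CACHE_SIZE_MB setting in the context above. The function below is a hypothetical illustration, not the actual estimator code:

    import os

    PROMPT_CACHE_SIZE_MB = int(os.getenv("PROMPT_CACHE_SIZE_MB", "512"))

    def memory_available_for_batches(free_bytes: int) -> int:
        # Reserve room for a completely full prefix cache up front, so that filling
        # the cache later cannot push the process past the estimator's budget.
        return free_bytes - PROMPT_CACHE_SIZE_MB * 1024 * 1024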
