Commit a90d443

OpenAI v1 Chat Completions API (#171)
1 parent 82dac66 commit a90d443

File tree: 24 files changed, +219 -27 lines


docs/reference/openapi.json

Lines changed: 5 additions & 0 deletions
@@ -607,6 +607,11 @@
         "api_token": {
           "type": "string",
           "nullable": true
+        },
+        "apply_chat_template": {
+          "type": "boolean",
+          "default": "false",
+          "example": true
         }
       }
     },
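The new `apply_chat_template` flag is part of the shared `GenerateParameters` schema, so chat-formatted prompts can also be sent to the plain `/generate` route. A minimal sketch in Python, assuming a LoRAX server at http://localhost:8080 and inferring from the server-side changes below that `inputs` is then treated as a JSON-encoded message list (host, port, and prompt are illustrative, not part of this commit):

import json
import requests

messages = [{"role": "user", "content": "What is deep learning?"}]

resp = requests.post(
    "http://localhost:8080/generate",  # assumed server address
    json={
        # With apply_chat_template=true, the server is expected to decode
        # `inputs` as a message list and format it with the tokenizer's
        # chat template before tokenization (see causal_lm.py below).
        "inputs": json.dumps(messages),
        "parameters": {"apply_chat_template": True, "max_new_tokens": 64},
    },
)
print(resp.json())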

proto/generate.proto

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ message Request {
     bool prefill_logprobs = 6;
     /// Adapter index
     uint32 adapter_index = 7;
+    /// Apply chat template to inputs
+    bool apply_chat_template = 8;
 }
 
 message Batch {

router/client/src/client.rs

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ impl Client {
                 }),
                 adapter_index: 0,
                 prefill_logprobs: true,
+                apply_chat_template: false,
             });
             n_tokens += max_input_length;
         }

router/src/health.rs

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ impl Health {
                 ignore_eos_token: false,
             }),
             adapter_index: 0,
+            apply_chat_template: false,
         };
         let batch = Batch {
             id: BATCH_ID,

router/src/lib.rs

Lines changed: 39 additions & 1 deletion
@@ -145,6 +145,9 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub decoder_input_details: bool,
     #[serde(default)]
+    #[schema(default = "false")]
+    pub apply_chat_template: bool,
+    #[serde(default)]
     #[schema(
         exclusive_minimum = 0,
         nullable = true,
@@ -177,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
         watermark: false,
         details: false,
         decoder_input_details: false,
+        apply_chat_template: false,
         seed: None,
     }
 }
@@ -320,7 +324,7 @@ struct UsageInfo {
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 struct ChatCompletionRequest {
     model: String,
-    messages: Vec<String>,
+    messages: Vec<std::collections::HashMap<String, String>>,
     temperature: Option<f32>,
     top_p: Option<f32>,
     n: Option<i32>,
@@ -451,6 +455,40 @@ impl From<CompletionRequest> for CompatGenerateRequest {
                 watermark: false,
                 details: true,
                 decoder_input_details: req.logprobs.is_some(),
+                apply_chat_template: false,
+                seed: None,
+            },
+            stream: req.stream.unwrap_or(false),
+        }
+    }
+}
+
+impl From<ChatCompletionRequest> for CompatGenerateRequest {
+    fn from(req: ChatCompletionRequest) -> Self {
+        CompatGenerateRequest {
+            inputs: serde_json::to_string(&req.messages).unwrap(),
+            parameters: GenerateParameters {
+                adapter_id: req.model.parse().ok(),
+                adapter_source: None,
+                api_token: None,
+                best_of: req.n.map(|x| x as usize),
+                temperature: req.temperature,
+                repetition_penalty: None,
+                top_k: None,
+                top_p: req.top_p,
+                typical_p: None,
+                do_sample: !req.n.is_none(),
+                max_new_tokens: req
+                    .max_tokens
+                    .map(|x| x as u32)
+                    .unwrap_or(default_max_new_tokens()),
+                return_full_text: None,
+                stop: req.stop,
+                truncate: None,
+                watermark: false,
+                details: true,
+                decoder_input_details: false,
+                apply_chat_template: true,
                 seed: None,
             },
             stream: req.stream.unwrap_or(false),
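This `From<ChatCompletionRequest>` conversion is what makes the route drop-in compatible: `model` maps to `adapter_id`, `messages` is serialized to JSON for the `inputs` field, and `apply_chat_template` is forced to `true`. A hedged usage sketch with the `openai` v1 Python client; the base URL and adapter name are placeholders, and whether the server validates the API key is not shown in this diff:

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/v1",  # assumed LoRAX address
    api_key="EMPTY",  # the client requires a value; server-side handling is an assumption
)

completion = client.chat.completions.create(
    model="my-org/my-lora-adapter",  # illustrative; becomes adapter_id above
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=64,
)
print(completion.choices[0].message.content)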

router/src/scheduler.rs

Lines changed: 1 addition & 0 deletions
@@ -334,6 +334,7 @@ impl AdapterSchedulerState {
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
                 adapter_index: adapter.index(),
+                apply_chat_template: entry.request.apply_chat_template,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());

router/src/server.rs

Lines changed: 66 additions & 5 deletions
@@ -3,10 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
-    CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
-    StreamResponse, Token, Validation,
+    BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest,
+    CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -78,7 +78,7 @@ async fn compat_generate(
     }
 }
 
-/// Generate tokens if `stream == false` or a stream of token if `stream == true`
+/// OpenAI compatible completions endpoint
 #[utoipa::path(
     post,
     tag = "LoRAX",
@@ -138,6 +138,66 @@ async fn completions_v1(
     }
 }
 
+/// OpenAI compatible chat completions endpoint
+#[utoipa::path(
+    post,
+    tag = "LoRAX",
+    path = "/v1/chat/completions",
+    request_body = ChatCompletionRequest,
+    responses(
+        (status = 200, description = "Generated Text",
+            content(
+                ("application/json" = ChatCompletionResponse),
+                ("text/event-stream" = ChatCompletionStreamResponse),
+            )),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json ! ({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json ! ({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json ! ({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json ! ({"error": "Incomplete generation"})),
+    )
+)]
+#[instrument(skip(infer, req))]
+async fn chat_completions_v1(
+    default_return_full_text: Extension<bool>,
+    infer: Extension<Infer>,
+    req: Json<ChatCompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let req = req.0;
+    let mut gen_req = CompatGenerateRequest::from(req);
+
+    // default return_full_text given the pipeline_tag
+    if gen_req.parameters.return_full_text.is_none() {
+        gen_req.parameters.return_full_text = Some(default_return_full_text.0)
+    }
+
+    // switch on stream
+    if gen_req.stream {
+        let callback = move |resp: StreamResponse| {
+            Event::default()
+                .json_data(CompletionStreamResponse::from(resp))
+                .map_or_else(
+                    |err| {
+                        tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, stream) =
+            generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
+        Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
+    } else {
+        let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
+    }
+}
+
 /// LoRAX endpoint info
 #[utoipa::path(
     get,
@@ -771,6 +831,7 @@ pub async fn run(
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
        .route("/v1/completions", post(completions_v1))
+        .route("/v1/chat/completions", post(chat_completions_v1))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
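When `stream` is set, the handler returns server-sent events; note that it reuses `CompletionStreamResponse` for frame serialization even on the chat route. A minimal sketch for consuming the stream with `requests` (URL and model are assumptions; the `data:` prefix follows the SSE convention of axum's `Sse` response):

import json
import requests

with requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed server address
    json={
        "model": "my-org/my-lora-adapter",  # illustrative adapter id
        "messages": [{"role": "user", "content": "Tell me a short story."}],
        "max_tokens": 128,
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # skip keep-alive blanks; payload lines look like b"data: {...}"
        if line.startswith(b"data:"):
            print(json.loads(line[len(b"data:"):]))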

router/src/validation.rs

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,7 @@ impl Validation {
             watermark,
             adapter_id,
             decoder_input_details,
+            apply_chat_template,
             ..
         } = request.parameters;
 
@@ -270,6 +271,7 @@ impl Validation {
             parameters,
             stopping_parameters,
             adapter,
+            apply_chat_template,
         })
     }
 
@@ -344,6 +346,7 @@ pub(crate) struct ValidGenerateRequest {
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
     pub adapter: Adapter,
+    pub apply_chat_template: bool,
 }
 
 #[derive(Error, Debug)]

server/lorax_server/models/bloom.py

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@
     weight_files,
     Weights,
 )
+from lorax_server.utils.tokenizer import TokenizerManager
 
 
 class BloomCausalLMBatch(CausalLMBatch):
@@ -28,10 +29,11 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch

server/lorax_server/models/causal_lm.py

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,4 @@
+import json
 import torch
 import inspect
 
@@ -15,6 +16,7 @@
 )
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from lorax_server.utils.tokenizer import TokenizerManager
 
 tracer = trace.get_tracer(__name__)
 
@@ -69,6 +71,7 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
@@ -86,7 +89,8 @@ def from_pb(
         adapter_indices_list = []
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            req_inputs = tokenizers.get_inputs(r, tokenizer)
+            inputs.append(req_inputs)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
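`TokenizerManager.get_inputs` itself is not included in this diff, so the following is only a plausible reconstruction inferred from the call site and the router-side serialization; the body below is an assumption, not the commit's code:

import json

def get_inputs(r, tokenizer):
    # Plain requests pass through untouched (assumed behavior).
    if not r.apply_chat_template:
        return r.inputs
    # The router serialized the chat `messages` list to a JSON string
    # (serde_json::to_string(&req.messages) in router/src/lib.rs above).
    messages = json.loads(r.inputs)
    # transformers tokenizers provide apply_chat_template for this step.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )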
