@@ -3,10 +3,10 @@ use crate::health::Health;
33use crate :: infer:: { InferError , InferResponse , InferStreamResponse } ;
44use crate :: validation:: ValidationError ;
55use crate :: {
6- BestOfSequence , CompatGenerateRequest , CompletionRequest , CompletionResponse ,
7- CompletionStreamResponse , Details , ErrorResponse , FinishReason , GenerateParameters ,
8- GenerateRequest , GenerateResponse , HubModelInfo , Infer , Info , PrefillToken , StreamDetails ,
9- StreamResponse , Token , Validation ,
6+ BestOfSequence , ChatCompletionRequest , CompatGenerateRequest , CompletionRequest ,
7+ CompletionResponse , CompletionStreamResponse , Details , ErrorResponse , FinishReason ,
8+ GenerateParameters , GenerateRequest , GenerateResponse , HubModelInfo , Infer , Info , PrefillToken ,
9+ StreamDetails , StreamResponse , Token , Validation ,
1010} ;
1111use axum:: extract:: Extension ;
1212use axum:: http:: { HeaderMap , Method , StatusCode } ;
@@ -78,7 +78,7 @@ async fn compat_generate(
7878 }
7979}
8080
81- /// Generate tokens if `stream == false` or a stream of token if `stream == true`
81+ /// OpenAI compatible completions endpoint
8282#[ utoipa:: path(
8383post,
8484tag = "LoRAX" ,
@@ -138,6 +138,66 @@ async fn completions_v1(
138138 }
139139}
140140
141+ /// OpenAI compatible chat completions endpoint
142+ #[ utoipa:: path(
143+ post,
144+ tag = "LoRAX" ,
145+ path = "/v1/chat/completions" ,
146+ request_body = ChatCompletionRequest ,
147+ responses(
148+ ( status = 200 , description = "Generated Text" ,
149+ content(
150+ ( "application/json" = ChatCompletionResponse ) ,
151+ ( "text/event-stream" = ChatCompletionStreamResponse ) ,
152+ ) ) ,
153+ ( status = 424 , description = "Generation Error" , body = ErrorResponse ,
154+ example = json ! ( { "error" : "Request failed during generation" } ) ) ,
155+ ( status = 429 , description = "Model is overloaded" , body = ErrorResponse ,
156+ example = json ! ( { "error" : "Model is overloaded" } ) ) ,
157+ ( status = 422 , description = "Input validation error" , body = ErrorResponse ,
158+ example = json ! ( { "error" : "Input validation error" } ) ) ,
159+ ( status = 500 , description = "Incomplete generation" , body = ErrorResponse ,
160+ example = json ! ( { "error" : "Incomplete generation" } ) ) ,
161+ )
162+ ) ]
163+ #[ instrument( skip( infer, req) ) ]
164+ async fn chat_completions_v1 (
165+ default_return_full_text : Extension < bool > ,
166+ infer : Extension < Infer > ,
167+ req : Json < ChatCompletionRequest > ,
168+ ) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
169+ let req = req. 0 ;
170+ let mut gen_req = CompatGenerateRequest :: from ( req) ;
171+
172+ // default return_full_text given the pipeline_tag
173+ if gen_req. parameters . return_full_text . is_none ( ) {
174+ gen_req. parameters . return_full_text = Some ( default_return_full_text. 0 )
175+ }
176+
177+ // switch on stream
178+ if gen_req. stream {
179+ let callback = move |resp : StreamResponse | {
180+ Event :: default ( )
181+ . json_data ( CompletionStreamResponse :: from ( resp) )
182+ . map_or_else (
183+ |err| {
184+ tracing:: error!( "Failed to serialize CompletionStreamResponse: {err}" ) ;
185+ Event :: default ( )
186+ } ,
187+ |data| data,
188+ )
189+ } ;
190+
191+ let ( headers, stream) =
192+ generate_stream_with_callback ( infer, Json ( gen_req. into ( ) ) , callback) . await ;
193+ Ok ( ( headers, Sse :: new ( stream) . keep_alive ( KeepAlive :: default ( ) ) ) . into_response ( ) )
194+ } else {
195+ let ( headers, generation) = generate ( infer, Json ( gen_req. into ( ) ) ) . await ?;
196+ // wrap generation inside a Vec to match api-inference
197+ Ok ( ( headers, Json ( vec ! [ CompletionResponse :: from( generation. 0 ) ] ) ) . into_response ( ) )
198+ }
199+ }
200+
141201/// LoRAX endpoint info
142202#[ utoipa:: path(
143203get,
@@ -771,6 +831,7 @@ pub async fn run(
771831 . route ( "/generate" , post ( generate) )
772832 . route ( "/generate_stream" , post ( generate_stream) )
773833 . route ( "/v1/completions" , post ( completions_v1) )
834+ . route ( "/v1/chat/completions" , post ( chat_completions_v1) )
774835 // AWS Sagemaker route
775836 . route ( "/invocations" , post ( compat_generate) )
776837 // Base Health route
0 commit comments