33
44use axum:: {
55 extract:: State ,
6- http:: { HeaderMap , StatusCode } ,
7- response:: Json ,
6+ http:: StatusCode ,
7+ response:: { sse:: Event , IntoResponse , Sse } ,
8+ Json ,
89} ;
10+ use futures_util:: stream;
911use rand:: Rng ;
1012use serde:: { Deserialize , Serialize } ;
1113use serde_json:: { json, Value } ;
14+ use std:: convert:: Infallible ;
1215use std:: time:: { SystemTime , UNIX_EPOCH } ;
1316
14- use crate :: models:: Usage ;
1517use crate :: server_state:: ServerState ;
1618
19+ #[ derive( Serialize , Debug ) ]
20+ pub struct Usage {
21+ pub prompt_tokens : u32 ,
22+ pub completion_tokens : u32 ,
23+ pub total_tokens : u32 ,
24+ }
25+
1726#[ derive( Deserialize ) ]
1827pub struct ChatCompletionRequest {
1928 pub messages : Option < Vec < Value > > ,
2029 pub model : Option < String > ,
30+ #[ serde( default ) ]
31+ pub stream : Option < bool > ,
2132 #[ serde( flatten) ]
2233 pub _other : Value ,
2334}
@@ -32,6 +43,33 @@ pub struct ChatCompletionResponse {
3243 pub usage : Usage ,
3344}
3445
46+ #[ derive( Serialize , Debug ) ]
47+ pub struct ChatCompletionChunk {
48+ pub id : String ,
49+ pub object : String ,
50+ pub created : u64 ,
51+ pub model : String ,
52+ pub choices : Vec < ChunkChoice > ,
53+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
54+ pub usage : Option < Usage > ,
55+ }
56+
57+ #[ derive( Serialize , Debug ) ]
58+ pub struct ChunkChoice {
59+ pub index : u32 ,
60+ pub delta : ChoiceDelta ,
61+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
62+ pub finish_reason : Option < String > ,
63+ }
64+
65+ #[ derive( Serialize , Debug , Default ) ]
66+ pub struct ChoiceDelta {
67+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
68+ pub role : Option < String > ,
69+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
70+ pub content : Option < String > ,
71+ }
72+
3573#[ derive( Serialize ) ]
3674pub struct Choice {
3775 pub index : u32 ,
@@ -48,7 +86,7 @@ pub struct Message {
4886pub async fn chat_completions (
4987 state : State < ServerState > ,
5088 Json ( payload) : Json < ChatCompletionRequest > ,
51- ) -> Result < ( HeaderMap , Json < Value > ) , ( StatusCode , HeaderMap , Json < Value > ) > {
89+ ) -> impl IntoResponse {
5290 if state. check_request_limit_exceeded ( ) {
5391 let headers = state. get_rate_limit_headers ( ) ;
5492 let error_body = json ! ( {
@@ -58,7 +96,7 @@ pub async fn chat_completions(
5896 "code" : "rate_limit_exceeded"
5997 }
6098 } ) ;
61- return Err ( ( StatusCode :: TOO_MANY_REQUESTS , headers, Json ( error_body) ) ) ;
99+ return ( StatusCode :: TOO_MANY_REQUESTS , headers, Json ( error_body) ) . into_response ( ) ;
62100 }
63101 state. increment_request_count ( ) ;
64102
@@ -75,14 +113,14 @@ pub async fn chat_completions(
75113 }
76114 } ) ;
77115
78- return Err ( ( status_code, headers, Json ( error_body) ) ) ;
116+ return ( status_code, headers, Json ( error_body) ) . into_response ( ) ;
79117 }
80118
81119 let response_length = state. get_response_length ( ) ;
82120
83121 if response_length == 0 {
84122 let headers = state. get_rate_limit_headers ( ) ;
85- return Err ( ( StatusCode :: NO_CONTENT , headers, Json ( json ! ( { } ) ) ) ) ;
123+ return ( StatusCode :: NO_CONTENT , headers, Json ( json ! ( { } ) ) ) . into_response ( ) ;
86124 }
87125
88126 let content = state. generate_lorem_content ( response_length) ;
@@ -106,10 +144,99 @@ pub async fn chat_completions(
106144 "code" : "rate_limit_exceeded"
107145 }
108146 } ) ;
109- return Err ( ( StatusCode :: TOO_MANY_REQUESTS , headers, Json ( error_body) ) ) ;
147+ return ( StatusCode :: TOO_MANY_REQUESTS , headers, Json ( error_body) ) . into_response ( ) ;
110148 }
111149 state. add_token_usage ( total_tokens) ;
112150
151+ let stream_response = payload. stream . unwrap_or ( false ) ;
152+ if stream_response {
153+ let id = format ! ( "chatcmpl-{}" , rand:: thread_rng( ) . gen :: <u32 >( ) ) ;
154+ let created = SystemTime :: now ( )
155+ . duration_since ( UNIX_EPOCH )
156+ . expect ( "should be able to get duration" )
157+ . as_secs ( ) ;
158+ let model = payload
159+ . model
160+ . clone ( )
161+ . unwrap_or_else ( || "gpt-3.5-turbo" . to_string ( ) ) ;
162+ let words = content
163+ . split_whitespace ( )
164+ . map ( |s| s. to_string ( ) )
165+ . collect :: < Vec < _ > > ( ) ;
166+
167+ let mut events = vec ! [ ] ;
168+
169+ // 1. First chunk with role
170+ let first_chunk = ChatCompletionChunk {
171+ id : id. clone ( ) ,
172+ object : "chat.completion.chunk" . to_string ( ) ,
173+ created,
174+ model : model. clone ( ) ,
175+ choices : vec ! [ ChunkChoice {
176+ index: 0 ,
177+ delta: ChoiceDelta {
178+ role: Some ( "assistant" . to_string( ) ) ,
179+ content: None ,
180+ } ,
181+ finish_reason: None ,
182+ } ] ,
183+ usage : None ,
184+ } ;
185+ events. push ( Ok :: < _ , Infallible > (
186+ Event :: default ( ) . data ( serde_json:: to_string ( & first_chunk) . unwrap ( ) ) ,
187+ ) ) ;
188+
189+ // 2. Content chunks
190+ for word in words {
191+ let chunk = ChatCompletionChunk {
192+ id : id. clone ( ) ,
193+ object : "chat.completion.chunk" . to_string ( ) ,
194+ created,
195+ model : model. clone ( ) ,
196+ choices : vec ! [ ChunkChoice {
197+ index: 0 ,
198+ delta: ChoiceDelta {
199+ role: None ,
200+ content: Some ( format!( "{} " , word) ) ,
201+ } ,
202+ finish_reason: None ,
203+ } ] ,
204+ usage : None ,
205+ } ;
206+ events. push ( Ok (
207+ Event :: default ( ) . data ( serde_json:: to_string ( & chunk) . unwrap ( ) )
208+ ) ) ;
209+ }
210+
211+ // 3. Final chunk with finish_reason
212+ let final_chunk = ChatCompletionChunk {
213+ id : id. clone ( ) ,
214+ object : "chat.completion.chunk" . to_string ( ) ,
215+ created,
216+ model : model. clone ( ) ,
217+ choices : vec ! [ ChunkChoice {
218+ index: 0 ,
219+ delta: Default :: default ( ) ,
220+ finish_reason: Some ( "stop" . to_string( ) ) ,
221+ } ] ,
222+ usage : Some ( Usage {
223+ prompt_tokens,
224+ completion_tokens,
225+ total_tokens,
226+ } ) ,
227+ } ;
228+ events. push ( Ok (
229+ Event :: default ( ) . data ( serde_json:: to_string ( & final_chunk) . unwrap ( ) )
230+ ) ) ;
231+
232+ // 4. Done message
233+ events. push ( Ok ( Event :: default ( ) . data ( "[DONE]" ) ) ) ;
234+
235+ let stream = stream:: iter ( events) ;
236+
237+ return Sse :: new ( stream) . into_response ( ) ;
238+ }
239+
113240 let response = ChatCompletionResponse {
114241 id : format ! ( "chatcmpl-{}" , rand:: thread_rng( ) . gen :: <u32 >( ) ) ,
115242 object : "chat.completion" . to_string ( ) ,
@@ -134,5 +261,5 @@ pub async fn chat_completions(
134261 } ;
135262
136263 let headers = state. get_rate_limit_headers ( ) ;
137- Ok ( ( headers, Json ( json ! ( response) ) ) )
264+ ( headers, Json ( json ! ( response) ) ) . into_response ( )
138265}
0 commit comments