Skip to content

Commit 63f34a2

Browse files
authored
add completion streaming (#6)
1 parent 4bef3d8 commit 63f34a2

File tree

3 files changed

+136
-21
lines changed

3 files changed

+136
-21
lines changed

src/chat_completions.rs

Lines changed: 136 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,32 @@
33

44
use axum::{
55
extract::State,
6-
http::{HeaderMap, StatusCode},
7-
response::Json,
6+
http::StatusCode,
7+
response::{sse::Event, IntoResponse, Sse},
8+
Json,
89
};
10+
use futures_util::stream;
911
use rand::Rng;
1012
use serde::{Deserialize, Serialize};
1113
use serde_json::{json, Value};
14+
use std::convert::Infallible;
1215
use std::time::{SystemTime, UNIX_EPOCH};
1316

14-
use crate::models::Usage;
1517
use crate::server_state::ServerState;
1618

19+
/// Token accounting serialized as the `usage` object of a completion.
///
/// Used by value in `ChatCompletionResponse` and optionally in
/// `ChatCompletionChunk`, where the handler fills it only on the final chunk.
#[derive(Serialize, Debug)]
20+
pub struct Usage {
21+
/// Token count attributed to the prompt.
pub prompt_tokens: u32,
22+
/// Token count attributed to the generated completion.
pub completion_tokens: u32,
23+
/// Total token count reported to the client.
// NOTE(review): presumably prompt + completion; the computation is not visible here — confirm.
pub total_tokens: u32,
24+
}
25+
1726
/// JSON request body accepted by the `chat_completions` handler.
///
/// All named fields are optional.
#[derive(Deserialize)]
1827
pub struct ChatCompletionRequest {
1928
/// Conversation messages, kept as untyped JSON values.
pub messages: Option<Vec<Value>>,
2029
/// Requested model name; the streaming path falls back to
/// `"gpt-3.5-turbo"` when this is absent.
pub model: Option<String>,
30+
#[serde(default)]
31+
/// When `Some(true)` the handler answers with a server-sent-event stream;
/// `None` is treated as `false`.
pub stream: Option<bool>,
2132
#[serde(flatten)]
2233
/// Catch-all for any other request fields (via `serde(flatten)`).
pub _other: Value,
2334
}
@@ -32,6 +43,33 @@ pub struct ChatCompletionResponse {
3243
pub usage: Usage,
3344
}
3445

46+
/// Payload of one server-sent event in a streamed chat completion.
///
/// The handler serializes each chunk to JSON and sends it as the `data` of an
/// SSE event. All chunks of one response share the same `id`, `created` and
/// `model`.
#[derive(Serialize, Debug)]
47+
pub struct ChatCompletionChunk {
48+
/// Response identifier; the handler builds it as `chatcmpl-<random u32>`.
pub id: String,
49+
/// Object type tag; the handler sets it to `"chat.completion.chunk"`.
pub object: String,
50+
/// Creation time in whole seconds since the UNIX epoch.
pub created: u64,
51+
/// Model name echoed back to the client.
pub model: String,
52+
/// Choice deltas carried by this chunk; the handler emits exactly one.
pub choices: Vec<ChunkChoice>,
53+
#[serde(skip_serializing_if = "Option::is_none")]
54+
/// Token usage; omitted from the JSON when `None`. The handler sets it
/// only on the final chunk.
pub usage: Option<Usage>,
55+
}
56+
57+
/// A single choice entry inside a `ChatCompletionChunk`.
#[derive(Serialize, Debug)]
58+
pub struct ChunkChoice {
59+
/// Position of this choice; the handler always uses `0`.
pub index: u32,
60+
/// Incremental role/content carried by this chunk.
pub delta: ChoiceDelta,
61+
#[serde(skip_serializing_if = "Option::is_none")]
62+
/// Omitted from the JSON while `None`; the handler sets `"stop"` on the
/// final chunk.
pub finish_reason: Option<String>,
63+
}
64+
65+
/// Incremental message fragment carried by a streamed choice.
///
/// Both fields are skipped during serialization when `None`, so the
/// `Default` value serializes to an empty JSON object (used by the handler
/// for the final chunk).
#[derive(Serialize, Debug, Default)]
66+
pub struct ChoiceDelta {
67+
#[serde(skip_serializing_if = "Option::is_none")]
68+
/// Speaker role; the handler sends `"assistant"` in the first chunk only.
pub role: Option<String>,
69+
#[serde(skip_serializing_if = "Option::is_none")]
70+
/// Text fragment; the handler sends one word followed by a space per chunk.
pub content: Option<String>,
71+
}
72+
3573
#[derive(Serialize)]
3674
pub struct Choice {
3775
pub index: u32,
@@ -48,7 +86,7 @@ pub struct Message {
4886
pub async fn chat_completions(
4987
state: State<ServerState>,
5088
Json(payload): Json<ChatCompletionRequest>,
51-
) -> Result<(HeaderMap, Json<Value>), (StatusCode, HeaderMap, Json<Value>)> {
89+
) -> impl IntoResponse {
5290
if state.check_request_limit_exceeded() {
5391
let headers = state.get_rate_limit_headers();
5492
let error_body = json!({
@@ -58,7 +96,7 @@ pub async fn chat_completions(
5896
"code": "rate_limit_exceeded"
5997
}
6098
});
61-
return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
99+
return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
62100
}
63101
state.increment_request_count();
64102

@@ -75,14 +113,14 @@ pub async fn chat_completions(
75113
}
76114
});
77115

78-
return Err((status_code, headers, Json(error_body)));
116+
return (status_code, headers, Json(error_body)).into_response();
79117
}
80118

81119
let response_length = state.get_response_length();
82120

83121
if response_length == 0 {
84122
let headers = state.get_rate_limit_headers();
85-
return Err((StatusCode::NO_CONTENT, headers, Json(json!({}))));
123+
return (StatusCode::NO_CONTENT, headers, Json(json!({}))).into_response();
86124
}
87125

88126
let content = state.generate_lorem_content(response_length);
@@ -106,10 +144,99 @@ pub async fn chat_completions(
106144
"code": "rate_limit_exceeded"
107145
}
108146
});
109-
return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
147+
return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
110148
}
111149
state.add_token_usage(total_tokens);
112150

151+
let stream_response = payload.stream.unwrap_or(false);
152+
if stream_response {
153+
let id = format!("chatcmpl-{}", rand::thread_rng().gen::<u32>());
154+
let created = SystemTime::now()
155+
.duration_since(UNIX_EPOCH)
156+
.expect("should be able to get duration")
157+
.as_secs();
158+
let model = payload
159+
.model
160+
.clone()
161+
.unwrap_or_else(|| "gpt-3.5-turbo".to_string());
162+
let words = content
163+
.split_whitespace()
164+
.map(|s| s.to_string())
165+
.collect::<Vec<_>>();
166+
167+
let mut events = vec![];
168+
169+
// 1. First chunk with role
170+
let first_chunk = ChatCompletionChunk {
171+
id: id.clone(),
172+
object: "chat.completion.chunk".to_string(),
173+
created,
174+
model: model.clone(),
175+
choices: vec![ChunkChoice {
176+
index: 0,
177+
delta: ChoiceDelta {
178+
role: Some("assistant".to_string()),
179+
content: None,
180+
},
181+
finish_reason: None,
182+
}],
183+
usage: None,
184+
};
185+
events.push(Ok::<_, Infallible>(
186+
Event::default().data(serde_json::to_string(&first_chunk).unwrap()),
187+
));
188+
189+
// 2. Content chunks
190+
for word in words {
191+
let chunk = ChatCompletionChunk {
192+
id: id.clone(),
193+
object: "chat.completion.chunk".to_string(),
194+
created,
195+
model: model.clone(),
196+
choices: vec![ChunkChoice {
197+
index: 0,
198+
delta: ChoiceDelta {
199+
role: None,
200+
content: Some(format!("{} ", word)),
201+
},
202+
finish_reason: None,
203+
}],
204+
usage: None,
205+
};
206+
events.push(Ok(
207+
Event::default().data(serde_json::to_string(&chunk).unwrap())
208+
));
209+
}
210+
211+
// 3. Final chunk with finish_reason
212+
let final_chunk = ChatCompletionChunk {
213+
id: id.clone(),
214+
object: "chat.completion.chunk".to_string(),
215+
created,
216+
model: model.clone(),
217+
choices: vec![ChunkChoice {
218+
index: 0,
219+
delta: Default::default(),
220+
finish_reason: Some("stop".to_string()),
221+
}],
222+
usage: Some(Usage {
223+
prompt_tokens,
224+
completion_tokens,
225+
total_tokens,
226+
}),
227+
};
228+
events.push(Ok(
229+
Event::default().data(serde_json::to_string(&final_chunk).unwrap())
230+
));
231+
232+
// 4. Done message
233+
events.push(Ok(Event::default().data("[DONE]")));
234+
235+
let stream = stream::iter(events);
236+
237+
return Sse::new(stream).into_response();
238+
}
239+
113240
let response = ChatCompletionResponse {
114241
id: format!("chatcmpl-{}", rand::thread_rng().gen::<u32>()),
115242
object: "chat.completion".to_string(),
@@ -134,5 +261,5 @@ pub async fn chat_completions(
134261
};
135262

136263
let headers = state.get_rate_limit_headers();
137-
Ok((headers, Json(json!(response))))
264+
(headers, Json(json!(response))).into_response()
138265
}

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use std::time::Duration;
1717
use tower_http::timeout::TimeoutLayer;
1818

1919
pub mod chat_completions;
20-
pub mod models;
2120
pub mod responses;
2221
pub mod server_state;
2322
use crate::server_state::ServerState;

src/models.rs

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)