
Commit f27a412

(fix) Streaming and Tool Calling (#29)
* fixes to various models to get tool calling and streaming working
* added missing assistant error
* cleaned up content option in streaming
1 parent ffc1403 commit f27a412

9 files changed: +123 -56 lines changed


backend/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

backend/src/common/mapping/openrouter.rs

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ impl Into<ToolCall> for ChatCompletionRequestToolCall {
     fn into(self) -> ToolCall {
         ToolCall {
             id: Some(self.id),
-            index: self.index,
+            index: None,
             kind: Some(self.kind),
             function_call: self.function_call.into(),
         }

backend/src/common/types/chat_request.rs

Lines changed: 41 additions & 11 deletions
@@ -63,12 +63,9 @@ pub enum ChatCompletionRequestMessage{
         #[serde(skip_serializing_if = "Option::is_none")]
         name: Option<String>,
     },
-    #[serde(rename_all = "camelCase")]
     Assistant {
-        content: String,
-        #[serde(skip_serializing_if = "Option::is_none")]
+        content: Option<String>,
         tool_calls: Option<Vec<ChatCompletionRequestToolCall>>,
-        #[serde(skip_serializing_if = "Option::is_none")]
         name: Option<String>,
     },
     Tool {

@@ -89,8 +86,6 @@ pub struct ChatCompletionRequestFunctionCall {
 pub struct ChatCompletionRequestToolCall {
     /// A unique identifier for the tool call.
     pub id: String,
-    /// The index of the tool call in the list of tool calls
-    pub index: u32,
     /// The type of call. It must be "function" for function calls.
     #[serde(rename = "type")]
     pub kind: String,

@@ -124,12 +119,47 @@ pub struct ChatCompletionRequestFunctionDescription {
 // Helper Methods for easy extraction
 impl ChatCompletionRequestMessage {
     /// Returns the content of the message regardless of its role
-    pub fn content(&self) -> &str {
+    pub fn content(&self) -> Option<String> {
+        match self {
+            ChatCompletionRequestMessage::System { content, .. } => Some(content.clone()),
+            ChatCompletionRequestMessage::User { content, .. } => Some(content.clone()),
+            ChatCompletionRequestMessage::Assistant { content, .. } => content.clone().map(|c| c),
+            ChatCompletionRequestMessage::Tool { content, .. } => Some(content.clone()),
+        }
+    }
+
+    pub fn system_content(&self) -> String {
         match self {
-            ChatCompletionRequestMessage::System { content, .. } => content,
-            ChatCompletionRequestMessage::User { content, .. } => content,
-            ChatCompletionRequestMessage::Assistant { content, .. } => content,
-            ChatCompletionRequestMessage::Tool { content, .. } => content,
+            ChatCompletionRequestMessage::System { content, .. } => content.clone(),
+            _ => "".to_string()
+        }
+    }
+
+    pub fn user_content(&self) -> String {
+        match self {
+            ChatCompletionRequestMessage::User { content, .. } => content.clone(),
+            _ => "".to_string()
+        }
+    }
+
+    pub fn assistant_content(&self) -> Option<String> {
+        match self {
+            ChatCompletionRequestMessage::Assistant { content, .. } => content.clone(),
+            _ => None
+        }
+    }
+
+    pub fn tool_content(&self) -> String {
+        match self {
+            ChatCompletionRequestMessage::Tool { content, .. } => content.clone(),
+            _ => "".to_string()
+        }
+    }
+
+    pub fn tool_call_id(&self) -> Option<String> {
+        match self {
+            ChatCompletionRequestMessage::Tool { tool_call_id, .. } => Some(tool_call_id.clone()),
+            _ => None
         }
     }

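Note on the accessor change: every role now yields Option<String> from content(), so callers can no longer assume text is present on an assistant turn that only carried tool calls. (Incidentally, content.clone().map(|c| c) is equivalent to content.clone().) A minimal standalone sketch of the accessor pattern, using a simplified enum rather than the repository's actual types:

    // Simplified stand-in for ChatCompletionRequestMessage; illustrative only.
    enum Message {
        User { content: String },
        Assistant { content: Option<String> },
    }

    impl Message {
        // Mirrors the new content() shape: every role yields Option<String>.
        fn content(&self) -> Option<String> {
            match self {
                Message::User { content } => Some(content.clone()),
                Message::Assistant { content } => content.clone(),
            }
        }
    }

    fn main() {
        let user = Message::User { content: "hi".to_string() };
        assert_eq!(user.content().as_deref(), Some("hi"));

        // An assistant turn that carried only tool calls has no text content.
        let msg = Message::Assistant { content: None };
        // Callers decide how to handle the absence instead of receiving "".
        println!("{}", msg.content().unwrap_or_else(|| "<no content>".into()));
    }
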
backend/src/common/types/chat_response.rs

Lines changed: 9 additions & 3 deletions
@@ -33,20 +33,24 @@ pub struct LlmServiceChatCompletionResponseUsage {
 #[derive(Debug, Deserialize, Serialize)]
 pub struct LlmServiceChatCompletionResponseMessage {
     pub role: String,
-    pub content: String,
+
+    pub content: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub name: Option<String>,
     // Optionally include tool_calls when the assistant message contains a tool call.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_calls: Option<Vec<LlmServiceChatCompletionResponseToolCall>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
 }

 #[derive(Debug, Deserialize, Serialize)]
 pub struct LlmServiceChatCompletionResponseToolCall {
     /// A unique identifier for the tool call.
     pub id: Option<String>,
     /// The index of the tool call in the list of tool calls
-    pub index: u32,
+    pub index: Option<u32>,
     /// The type of call. When streaming, the first chunk only will contain "function".
     #[serde(rename = "type")]
     pub kind: Option<String>,

@@ -69,7 +73,7 @@ impl LlmServiceChatCompletionResponse {
     /// Useful for handling streamed responses which are typically simpler.
     pub fn new_streamed(
         id: String,
-        message_content: String,
+        message_content: Option<String>,
         model: String,
         created: i64,
         prompt_tokens: Option<u32>,

@@ -84,6 +88,7 @@ impl LlmServiceChatCompletionResponse {
                 content: message_content,
                 name: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
             finish_reason: Some("stop".to_string()),
             native_finish_reason: None,

@@ -122,6 +127,7 @@ impl From<ChatCompletionResponse> for LlmServiceChatCompletionResponse {
                     role: choice.message.role,
                     content: choice.message.content,
                     name: choice.message.name,
+                    tool_call_id: choice.message.tool_call_id,
                     tool_calls: choice.message.tool_calls.map(|tool_calls| {
                         tool_calls.into_iter().map(|tool_call| {
                             LlmServiceChatCompletionResponseToolCall {

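The optional content, index, and new tool_call_id fields exist because streamed chunks and tool-call messages may carry no text at all. A rough sketch of how skip_serializing_if = "Option::is_none" plays out, assuming the serde/serde_json dependencies already used by these types; the struct below is a simplified stand-in, not the real LlmServiceChatCompletionResponseMessage:

    use serde::Serialize;

    // Minimal stand-in mirroring the optional fields above; illustrative only.
    #[derive(Serialize)]
    struct ResponseMessage {
        role: String,
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_call_id: Option<String>,
    }

    fn main() {
        let msg = ResponseMessage {
            role: "assistant".to_string(),
            content: None,      // serialized as null (no skip attribute on content)
            tool_call_id: None, // omitted from the JSON entirely
        };
        // Prints: {"role":"assistant","content":null}
        println!("{}", serde_json::to_string(&msg).unwrap());
    }
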
backend/src/controllers/prompt_eval_run.rs

Lines changed: 4 additions & 1 deletion
@@ -73,10 +73,13 @@ pub async fn execute_eval_run(
         .map_err(|_| AppError::InternalServerError("Something went wrong".to_string()))?;

     if let Some(c) = res.0.choices.first() {
+        // TODO: We should make the DB field nullable so we don't have to hack this
+        let content = c.message.content.clone().map(|c| c.to_string()).unwrap_or("".to_string());
+
         let eval_run = state
             .db
             .prompt_eval_run
-            .create(&run_id, prompt_version_id, e.id, None, &c.message.content)
+            .create(&run_id, prompt_version_id, e.id, None, &content)
             .await?;

         eval_runs.push(eval_run);

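Small aside on the fallback above: since content is already an Option<String>, the .map(|c| c.to_string()) is a no-op and unwrap_or_default() yields the same empty-string fallback. A tiny equivalence check, not a change to the repository:

    fn main() {
        let content: Option<String> = None;

        // What the diff does: map to String (a no-op on String) and fall back to "".
        let a = content.clone().map(|c| c.to_string()).unwrap_or("".to_string());

        // Equivalent, shorter form.
        let b = content.clone().unwrap_or_default();

        assert_eq!(a, b);
        assert_eq!(a, "");
    }
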
backend/src/services/llm.rs

Lines changed: 12 additions & 2 deletions
@@ -45,20 +45,30 @@ impl Llm {
         let res = self.send_request().await?;

         if let Some(c) = res.0.choices.first() {
+            // We don't need to validate the tool response
+            if c.message.role == "tool" {
+                return Ok(res);
+            }
+
+            let content = match &c.message.content {
+                Some(c) => c.to_string(),
+                None => return Err(LlmError::MissingAssistantContent)
+            };
+
             // if we have a JSON schema available lets use it
             // Otherwise just make sure it's valid JSON and return
             match &self.props.request.response_format {
                 Some(rf) => {
                     match &rf.json_schema {
                         Some(js) => {
-                            let is_valid = &self.validate_schema(&c.message.content, &js.schema)?;
+                            let is_valid = &self.validate_schema(&content, &js.schema)?;
                             if !is_valid {
                                 tracing::error!("The schema was not valid");
                                 return Err(LlmError::InvalidJsonSchema);
                             }
                         },
                         None => {
-                            let _json: serde_json::Value = serde_json::from_str(&c.message.content)?;
+                            let _json: serde_json::Value = serde_json::from_str(&content)?;
                         }
                     }
                 },

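The new guard means post-processing only validates assistant text: tool-role messages pass through untouched, and an assistant message with no content becomes an explicit MissingAssistantContent error instead of an empty string. A standalone sketch of that gate, with simplified message and error types rather than the crate's real ones:

    // Simplified stand-in for the validation gate above; illustrative only.
    #[derive(Debug)]
    enum LlmError {
        MissingAssistantContent,
    }

    struct Message {
        role: String,
        content: Option<String>,
    }

    fn validated_content(msg: &Message) -> Result<Option<String>, LlmError> {
        // Tool responses are passed through without content validation.
        if msg.role == "tool" {
            return Ok(None);
        }
        // Any other role is expected to carry text content.
        match &msg.content {
            Some(c) => Ok(Some(c.clone())),
            None => Err(LlmError::MissingAssistantContent),
        }
    }

    fn main() {
        let tool = Message { role: "tool".to_string(), content: None };
        let empty_assistant = Message { role: "assistant".to_string(), content: None };

        assert!(validated_content(&tool).is_ok());
        assert!(matches!(
            validated_content(&empty_assistant),
            Err(LlmError::MissingAssistantContent)
        ));
        println!("gate behaves as expected");
    }
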
backend/src/services/providers/openrouter.rs

Lines changed: 12 additions & 5 deletions
@@ -39,8 +39,9 @@ impl<'a> OpenrouterProvider<'a> {
         let messages = self.props.request.messages.iter().map(|msg| {
             openrouter_api::types::chat::Message {
                 role: msg.role().to_string(),
-                content: msg.content().to_string(),
+                content: msg.content(),
                 name: msg.name().map(|n| n.to_string()),
+                tool_call_id: msg.tool_call_id(),
                 tool_calls: match msg {
                     ChatCompletionRequestMessage::Assistant { tool_calls, .. } => {
                         match tool_calls {

@@ -77,8 +78,9 @@ impl<'a> OpenrouterProvider<'a> {
         let messages: Vec<openrouter_api::types::chat::Message> = self.props.request.messages.iter().map(|msg| {
             openrouter_api::types::chat::Message {
                 role: msg.role().to_string(),
-                content: msg.content().to_string(),
+                content: msg.content(),
                 name: msg.name().map(|n| n.to_string()),
+                tool_call_id: msg.tool_call_id(),
                 tool_calls: match msg {
                     ChatCompletionRequestMessage::Assistant { tool_calls, .. } => {
                         match tool_calls {

@@ -105,14 +107,14 @@ impl<'a> OpenrouterProvider<'a> {
         };

         let mut stream = self.client.chat()?.chat_completion_stream(request);
-        let mut content = String::new();
+        let mut content: Option<String> = None;
         let mut prompt_tokens = 0;
         let mut completion_tokens = 0;
         let mut total_tokens = 0;
         let mut id = String::new();

         while let Some(chunk) = stream.next().await {
-            tracing::info!("chunk: {:?}", chunk);
+            tracing::debug!("chunk: {:?}", chunk);
             match chunk {
                 Ok(c) => {
                     id = c.id.clone();

@@ -125,10 +127,15 @@ impl<'a> OpenrouterProvider<'a> {

                     if let Some(c) = &c.choices.first() {
                         if let Some(c) = &c.delta.content {
-                            content += &c;
+                            match &mut content {
+                                Some(cnt) => cnt.push_str(&c),
+                                None => content = Some(c.to_string())
+                            }
                         }
                     }

+                    // TODO: Capture tool calls
+
                     if let Err(_) = tx.send(Ok(c.into())).await {
                         break;
                     }

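For the streaming loop, the accumulator is now Option<String>, so a stream that never produces a text delta ends as None rather than an empty string, which matters once the response message's content is optional. A minimal sketch of that accumulation shape; the deltas vector below is made up for illustration:

    fn main() {
        // Simulated streamed delta contents; None models chunks with no text
        // (e.g. role-only or tool-call chunks).
        let deltas: Vec<Option<&str>> = vec![None, Some("Hel"), Some("lo"), None, Some(" world")];

        // Same shape as the streaming loop above: content stays None until the
        // first text delta arrives, then later deltas are appended.
        let mut content: Option<String> = None;
        for delta in deltas.into_iter().flatten() {
            match &mut content {
                Some(cnt) => cnt.push_str(delta),
                None => content = Some(delta.to_string()),
            }
        }

        assert_eq!(content.as_deref(), Some("Hello world"));
        println!("{:?}", content);
    }
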
backend/src/services/types/llm_error.rs

Lines changed: 3 additions & 1 deletion
@@ -59,6 +59,8 @@ pub enum LlmError {
     PromptTooLong(usize, usize),
     #[error("Content policy violation: {0}")]
     ContentPolicy(String),
+    #[error("Missing assistant content when expected")]
+    MissingAssistantContent,

     // Concurrency/Task errors
     #[error("MPSC Sender failed to send message in channel: {0}")]

@@ -96,7 +98,7 @@ pub enum LlmError {
     #[error("Serialization error: {0}")]
     SerializationError(String),
     #[error("Deserialization error: {0}")]
-    DeserializationError(String)
+    DeserializationError(String),
 }

 impl From<openrouter_api::Error> for LlmError {

0 commit comments
