Skip to content

Commit eaadff1

Browse files
committed
fix: deserialize tool call args
Signed-off-by: Ryan Lempka <[email protected]>
1 parent eedfc3d commit eaadff1

File tree

1 file changed

+146
-1
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+146
-1
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,49 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
121121
Value::from_serialize(&updated_messages)
122122
}
123123

124+
fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
125+
// Deserialize tool call arguments from JSON strings to objects/arrays before template rendering
126+
// avoids double encoding and enables iteration
127+
let Some(msgs) = messages.as_array_mut() else {
128+
return;
129+
};
130+
131+
for msg in msgs.iter_mut() {
132+
if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
133+
for tc in tool_calls {
134+
if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut()) {
135+
if let Some(args) = function.get_mut("arguments") {
136+
if let Some(s) = args.as_str() {
137+
if let Result::Ok(parsed) = serde_json::from_str(s) {
138+
*args = parsed;
139+
}
140+
}
141+
}
142+
}
143+
}
144+
}
145+
146+
if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut()) {
147+
if let Some(args) = function_call.get_mut("arguments") {
148+
if let Some(s) = args.as_str() {
149+
if let Result::Ok(parsed) = serde_json::from_str(s) {
150+
*args = parsed;
151+
}
152+
}
153+
}
154+
}
155+
}
156+
}
157+
124158
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
125159
fn model(&self) -> String {
126160
self.inner.model.clone()
127161
}
128162

129163
fn messages(&self) -> Value {
130-
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
164+
let mut messages_json = serde_json::to_value(&self.inner.messages).unwrap();
165+
166+
normalize_tool_arguments_in_messages(&mut messages_json);
131167

132168
let needs_fixing = if let Some(arr) = messages_json.as_array() {
133169
arr.iter()
@@ -700,4 +736,113 @@ NORMAL MODE
700736
assert!(messages[0]["content"].is_array());
701737
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
702738
}
739+
740+
#[test]
741+
fn test_normalize_tool_arguments_tojson() {
742+
use minijinja::{Environment, context};
743+
744+
let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
745+
746+
// Message with tool_calls containing JSON string arguments
747+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
748+
"role": "assistant",
749+
"tool_calls": [{
750+
"type": "function",
751+
"function": {
752+
"name": "get_current_weather",
753+
"arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
754+
}
755+
}]
756+
})]);
757+
758+
normalize_tool_arguments_in_messages(&mut messages);
759+
760+
let mut env = Environment::new();
761+
env.add_filter("tojson", super::super::tokcfg::tojson);
762+
env.add_template("t", tmpl).unwrap();
763+
let out = env
764+
.get_template("t")
765+
.unwrap()
766+
.render(context! { messages => messages.as_array().unwrap() })
767+
.unwrap();
768+
769+
// Should produce clean JSON without double-encoding
770+
assert_eq!(
771+
out,
772+
r#"{"format":"celsius","location":"San Francisco, CA"}"#
773+
);
774+
}
775+
776+
#[test]
777+
fn test_normalize_tool_arguments_items_loop() {
778+
use minijinja::{Environment, context};
779+
780+
let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;
781+
782+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
783+
"role": "assistant",
784+
"tool_calls": [{
785+
"type": "function",
786+
"function": {
787+
"name": "f",
788+
"arguments": "{\"a\":1,\"b\":\"x\"}"
789+
}
790+
}]
791+
})]);
792+
793+
normalize_tool_arguments_in_messages(&mut messages);
794+
795+
let mut env = Environment::new();
796+
env.add_template("t", tmpl).unwrap();
797+
let out = env
798+
.get_template("t")
799+
.unwrap()
800+
.render(context! { messages => messages.as_array().unwrap() })
801+
.unwrap();
802+
803+
// Order-insensitive check: either a=1;b=x; or b=x;a=1;
804+
assert!(out == "a=1;b=x;" || out == "b=x;a=1;");
805+
}
806+
807+
#[test]
808+
fn test_normalize_tool_arguments_legacy_function_call() {
809+
// Test deprecated function_call format (OpenAI compat)
810+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
811+
"role": "assistant",
812+
"function_call": {
813+
"name": "get_weather",
814+
"arguments": "{\"location\":\"NYC\"}"
815+
}
816+
})]);
817+
818+
normalize_tool_arguments_in_messages(&mut messages);
819+
820+
assert_eq!(
821+
messages[0]["function_call"]["arguments"],
822+
serde_json::json!({"location": "NYC"})
823+
);
824+
}
825+
826+
#[test]
827+
fn test_normalize_tool_arguments_malformed_json_passthrough() {
828+
// Malformed JSON should be left as a string
829+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
830+
"role": "assistant",
831+
"tool_calls": [{
832+
"type": "function",
833+
"function": {
834+
"name": "f",
835+
"arguments": "not valid json at all"
836+
}
837+
}]
838+
})]);
839+
840+
normalize_tool_arguments_in_messages(&mut messages);
841+
842+
// Should remain as string
843+
assert_eq!(
844+
messages[0]["tool_calls"][0]["function"]["arguments"],
845+
serde_json::Value::String("not valid json at all".to_string())
846+
);
847+
}
703848
}

0 commit comments

Comments
 (0)