Skip to content

Commit b9d093d

Browse files
committed
fix: deserialize tool call args
Signed-off-by: Ryan Lempka <[email protected]>
1 parent 3fd0ab3 commit b9d093d

File tree

1 file changed

+147
-1
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+147
-1
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,49 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
121121
Value::from_serialize(&updated_messages)
122122
}
123123

124+
/// Decodes JSON-encoded tool-call arguments in place ahead of template rendering.
///
/// When `messages` is a JSON array, each element is inspected and the `arguments` field found
/// under every `tool_calls[*].function` object, and under the legacy `function_call` object,
/// is replaced by its parsed JSON value if it currently holds a string containing valid JSON.
/// Handing templates the structured value avoids double encoding and lets them iterate over
/// the arguments.
///
/// Everything that does not match this shape is left untouched: a non-array `messages`,
/// missing or differently-typed fields, `arguments` that are not strings, and strings that
/// fail to parse. Any valid JSON document is accepted, so a string such as `"1"` is decoded
/// to a number, not only objects and arrays.
fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
    /// Replaces `args` with its parsed form when it is a string holding valid JSON.
    fn decode_in_place(args: &mut serde_json::Value) {
        // `.ok()` discards the parse error on purpose: malformed input keeps its original
        // string value so the template still receives something to render.
        let decoded = args
            .as_str()
            .and_then(|raw| serde_json::from_str::<serde_json::Value>(raw).ok());
        if let Some(value) = decoded {
            *args = value;
        }
    }

    let Some(msgs) = messages.as_array_mut() else {
        return;
    };

    for msg in msgs.iter_mut() {
        // Current format: `tool_calls` is an array of `{ "function": { "arguments": ... } }`.
        if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
            // `get_mut` with a string key yields `None` for non-objects, so entries whose
            // `function` is missing or not an object are skipped.
            let all_args = tool_calls
                .iter_mut()
                .filter_map(|tc| tc.get_mut("function").and_then(|f| f.get_mut("arguments")));
            for args in all_args {
                decode_in_place(args);
            }
        }

        // Legacy format: a single `function_call` object carried directly on the message.
        if let Some(args) = msg
            .get_mut("function_call")
            .and_then(|fc| fc.get_mut("arguments"))
        {
            decode_in_place(args);
        }
    }
}
157+
124158
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
125159
fn model(&self) -> String {
126160
self.inner.model.clone()
127161
}
128162

129163
fn messages(&self) -> Value {
130-
let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
164+
let mut messages_json = serde_json::to_value(&self.inner.messages).unwrap();
165+
166+
normalize_tool_arguments_in_messages(&mut messages_json);
131167

132168
let needs_fixing = if let Some(arr) = messages_json.as_array() {
133169
arr.iter()
@@ -746,4 +782,114 @@ NORMAL MODE
746782
let s = dummy_state(vec![]);
747783
assert!(s.should_add_generation_prompt());
748784
}
785+
786+
#[test]
fn test_normalize_tool_arguments_tojson() {
    use minijinja::{Environment, context};

    const TEMPLATE: &str = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
    // A single level of JSON encoding: the arguments rendered as an object, not a quoted string.
    const EXPECTED: &str = r#"{"format":"celsius","location":"San Francisco, CA"}"#;

    // One assistant message whose tool call carries its arguments as a JSON-encoded string.
    let mut payload = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
            }
        }]
    }]);

    normalize_tool_arguments_in_messages(&mut payload);

    let mut jinja = Environment::new();
    jinja.add_filter("tojson", super::super::tokcfg::tojson);
    jinja.add_template("t", TEMPLATE).unwrap();

    let template = jinja.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => payload.as_array().unwrap() })
        .unwrap();

    assert_eq!(rendered, EXPECTED);
}
821+
822+
#[test]
fn test_normalize_tool_arguments_items_loop() {
    use minijinja::{Environment, context};

    // Iterating with `items` only works once the arguments are a mapping rather than a string.
    const TEMPLATE: &str = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;

    let mut payload = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": "{\"a\":1,\"b\":\"x\"}"
            }
        }]
    }]);

    normalize_tool_arguments_in_messages(&mut payload);

    let mut jinja = Environment::new();
    jinja.add_template("t", TEMPLATE).unwrap();

    let template = jinja.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => payload.as_array().unwrap() })
        .unwrap();

    // Key order is not pinned down here, so accept either ordering of the two pairs.
    assert!(matches!(rendered.as_str(), "a=1;b=x;" | "b=x;a=1;"));
}
852+
853+
#[test]
fn test_normalize_tool_arguments_legacy_function_call() {
    // The deprecated `function_call` field sits directly on the message instead of inside
    // a `tool_calls` array; its string arguments must be decoded as well.
    let mut payload = serde_json::json!([{
        "role": "assistant",
        "function_call": {
            "name": "get_weather",
            "arguments": "{\"location\":\"NYC\"}"
        }
    }]);

    normalize_tool_arguments_in_messages(&mut payload);

    let decoded = &payload[0]["function_call"]["arguments"];
    let expected = serde_json::json!({"location": "NYC"});
    assert_eq!(*decoded, expected);
}
871+
872+
#[test]
fn test_normalize_tool_arguments_malformed_json_passthrough() {
    const MALFORMED: &str = "not valid json at all";

    // Arguments that cannot be parsed as JSON must come out exactly as they went in.
    let mut payload = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": MALFORMED
            }
        }]
    }]);

    normalize_tool_arguments_in_messages(&mut payload);

    // `as_str` is `Some` only for a JSON string, so this checks both the type and the content.
    let untouched = &payload[0]["tool_calls"][0]["function"]["arguments"];
    assert_eq!(untouched.as_str(), Some(MALFORMED));
}
749895
}

0 commit comments

Comments
 (0)