Skip to content

Commit 51c4fe6

Browse files
authored
fix: deserialize tool call args (#4176)
Signed-off-by: Ryan Lempka <[email protected]>
1 parent 441473c commit 51c4fe6

File tree

1 file changed

+185
-1
lines changed
  • lib/llm/src/preprocessor/prompt/template

1 file changed

+185
-1
lines changed

lib/llm/src/preprocessor/prompt/template/oai.rs

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use super::*;
55

66
use minijinja::{context, value::Value};
7+
use std::result::Result::Ok;
78

89
use crate::protocols::openai::{
910
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
@@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
121122
Value::from_serialize(&updated_messages)
122123
}
123124

125+
fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
126+
// Deserialize tool call arguments from JSON strings to objects/arrays before template rendering
127+
// avoids double encoding and enables iteration
128+
let Some(msgs) = messages.as_array_mut() else {
129+
return;
130+
};
131+
132+
for msg in msgs.iter_mut() {
133+
if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
134+
for tc in tool_calls {
135+
if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut())
136+
&& let Some(args) = function.get_mut("arguments")
137+
&& let Some(s) = args.as_str()
138+
&& let Ok(parsed) = serde_json::from_str(s)
139+
{
140+
*args = parsed;
141+
}
142+
}
143+
}
144+
145+
if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut())
146+
&& let Some(args) = function_call.get_mut("arguments")
147+
&& let Some(s) = args.as_str()
148+
&& let Ok(parsed) = serde_json::from_str(s)
149+
{
150+
*args = parsed;
151+
}
152+
}
153+
}
154+
124155
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
125156
fn model(&self) -> String {
126157
self.inner.model.clone()
@@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
267298
add_generation_prompt
268299
);
269300

301+
let messages_canonical = req.messages();
302+
let mut messages_for_template: serde_json::Value =
303+
serde_json::to_value(&messages_canonical).unwrap();
304+
normalize_tool_arguments_in_messages(&mut messages_for_template);
305+
270306
let ctx = context! {
271-
messages => req.messages(),
307+
messages => messages_for_template,
272308
tools => tools,
273309
bos_token => self.config.bos_tok(),
274310
eos_token => self.config.eos_tok(),
@@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
298334
mod tests {
299335
use super::*;
300336
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
337+
use minijinja::{Environment, context};
301338

302339
#[test]
303340
fn test_may_be_fix_tool_schema_missing_type_and_properties() {
@@ -705,6 +742,153 @@ NORMAL MODE
705742
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
706743
}
707744

745+
#[test]
746+
fn test_normalize_tool_arguments_tojson() {
747+
let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
748+
749+
// Message with tool_calls containing JSON string arguments
750+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
751+
"role": "assistant",
752+
"tool_calls": [{
753+
"type": "function",
754+
"function": {
755+
"name": "get_current_weather",
756+
"arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
757+
}
758+
}]
759+
})]);
760+
761+
normalize_tool_arguments_in_messages(&mut messages);
762+
763+
let mut env = Environment::new();
764+
env.add_filter("tojson", super::super::tokcfg::tojson);
765+
env.add_template("t", tmpl).unwrap();
766+
let out = env
767+
.get_template("t")
768+
.unwrap()
769+
.render(context! { messages => messages.as_array().unwrap() })
770+
.unwrap();
771+
772+
// Should produce clean JSON without double-encoding
773+
assert_eq!(
774+
out,
775+
r#"{"format":"celsius","location":"San Francisco, CA"}"#
776+
);
777+
}
778+
779+
#[test]
780+
fn test_normalize_tool_arguments_items_loop() {
781+
let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;
782+
783+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
784+
"role": "assistant",
785+
"tool_calls": [{
786+
"type": "function",
787+
"function": {
788+
"name": "f",
789+
"arguments": "{\"a\":1,\"b\":\"x\"}"
790+
}
791+
}]
792+
})]);
793+
794+
normalize_tool_arguments_in_messages(&mut messages);
795+
796+
let mut env = Environment::new();
797+
env.add_template("t", tmpl).unwrap();
798+
let out = env
799+
.get_template("t")
800+
.unwrap()
801+
.render(context! { messages => messages.as_array().unwrap() })
802+
.unwrap();
803+
804+
assert!(out == "a=1;b=x;" || out == "b=x;a=1;");
805+
}
806+
807+
#[test]
808+
fn test_normalize_tool_arguments_legacy_function_call() {
809+
// Test deprecated function_call format (OpenAI compat)
810+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
811+
"role": "assistant",
812+
"function_call": {
813+
"name": "get_weather",
814+
"arguments": "{\"location\":\"NYC\"}"
815+
}
816+
})]);
817+
818+
normalize_tool_arguments_in_messages(&mut messages);
819+
820+
assert_eq!(
821+
messages[0]["function_call"]["arguments"],
822+
serde_json::json!({"location": "NYC"})
823+
);
824+
}
825+
826+
#[test]
827+
fn test_normalize_tool_arguments_malformed_json_passthrough() {
828+
// Malformed JSON should be left as a string
829+
let mut messages = serde_json::Value::Array(vec![serde_json::json!({
830+
"role": "assistant",
831+
"tool_calls": [{
832+
"type": "function",
833+
"function": {
834+
"name": "f",
835+
"arguments": "not valid json at all"
836+
}
837+
}]
838+
})]);
839+
840+
normalize_tool_arguments_in_messages(&mut messages);
841+
842+
assert_eq!(
843+
messages[0]["tool_calls"][0]["function"]["arguments"],
844+
serde_json::Value::String("not valid json at all".to_string())
845+
);
846+
}
847+
848+
#[test]
849+
fn test_normalize_tool_arguments_with_multimodal_content() {
850+
let json_str = r#"{
851+
"model": "gpt-4o",
852+
"messages": [
853+
{
854+
"role": "user",
855+
"content": [
856+
{"type": "text", "text": "Check this:"},
857+
{"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
858+
{"type": "text", "text": "Interesting?"}
859+
]
860+
},
861+
{
862+
"role": "assistant",
863+
"tool_calls": [{
864+
"id": "call_123",
865+
"type": "function",
866+
"function": {
867+
"name": "analyze_video",
868+
"arguments": "{\"url\":\"https://example.com/vid.mp4\",\"format\":\"mp4\"}"
869+
}
870+
}]
871+
}
872+
]
873+
}"#;
874+
875+
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
876+
let mut messages = serde_json::to_value(request.messages()).unwrap();
877+
878+
normalize_tool_arguments_in_messages(&mut messages);
879+
880+
// Multimodal content preserved as array
881+
assert!(messages[0]["content"].is_array());
882+
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
883+
884+
// Tool arguments deserialized to object
885+
assert!(messages[1]["tool_calls"][0]["function"]["arguments"].is_object());
886+
assert_eq!(
887+
messages[1]["tool_calls"][0]["function"]["arguments"]["url"],
888+
"https://example.com/vid.mp4"
889+
);
890+
}
891+
708892
fn user() -> Msg {
709893
Msg::User(Default::default())
710894
}

0 commit comments

Comments
 (0)