diff --git a/lib/llm/src/preprocessor/prompt/template/oai.rs b/lib/llm/src/preprocessor/prompt/template/oai.rs
index 78c6d35eec..9d02e3a698 100644
--- a/lib/llm/src/preprocessor/prompt/template/oai.rs
+++ b/lib/llm/src/preprocessor/prompt/template/oai.rs
@@ -4,6 +4,7 @@
 use super::*;
 
 use minijinja::{context, value::Value};
+use std::result::Result::Ok; // presumably shadows an `Ok` from `super::*` — TODO confirm
 
 use crate::protocols::openai::{
     chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
@@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
     Value::from_serialize(&updated_messages)
 }
 
+fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
+    // Replace each tool-call `arguments` JSON string with the value it encodes, so templates get
+    // structured data: no double encoding under `tojson`, and the arguments can be iterated.
+    let Some(msgs) = messages.as_array_mut() else {
+        return; // `messages` is not a JSON array: nothing to normalize
+    };
+
+    for msg in msgs.iter_mut() {
+        if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
+            for tc in tool_calls {
+                if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut())
+                    && let Some(args) = function.get_mut("arguments")
+                    && let Some(s) = args.as_str()
+                    && let Ok(parsed) = serde_json::from_str(s)
+                {
+                    *args = parsed; // strings that fail to parse never get here and stay as-is
+                }
+            }
+        }
+
+        if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut())
+            && let Some(args) = function_call.get_mut("arguments")
+            && let Some(s) = args.as_str()
+            && let Ok(parsed) = serde_json::from_str(s)
+        {
+            *args = parsed; // same rewrite for the deprecated `function_call` field
+        }
+    }
+}
+
 impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
     fn model(&self) -> String {
         self.inner.model.clone()
@@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
             add_generation_prompt
         );
 
+        let messages_canonical = req.messages();
+        let mut messages_for_template: serde_json::Value =
+            serde_json::to_value(&messages_canonical).unwrap(); // NOTE(review): panics on error
+        normalize_tool_arguments_in_messages(&mut messages_for_template);
+
         let ctx = context! {
-            messages => req.messages(),
+            messages => messages_for_template, // the normalized copy, not `req.messages()`
             tools => tools,
             bos_token => self.config.bos_tok(),
             eos_token => self.config.eos_tok(),
@@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
 mod tests {
     use super::*;
     use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
+    use minijinja::{Environment, context};
 
     #[test]
     fn test_may_be_fix_tool_schema_missing_type_and_properties() {
@@ -705,6 +742,153 @@ NORMAL MODE
         assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
     }
 
+    #[test]
+    fn test_normalize_tool_arguments_tojson() {
+        let tmpl = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;
+
+        // Assistant message whose tool call carries its arguments as a JSON-encoded string
+        let mut messages = serde_json::Value::Array(vec![serde_json::json!({
+            "role": "assistant",
+            "tool_calls": [{
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
+                }
+            }]
+        })]);
+
+        normalize_tool_arguments_in_messages(&mut messages);
+
+        let mut env = Environment::new();
+        env.add_filter("tojson", super::super::tokcfg::tojson); // filter the template relies on
+        env.add_template("t", tmpl).unwrap();
+        let out = env
+            .get_template("t")
+            .unwrap()
+            .render(context! { messages => messages.as_array().unwrap() })
+            .unwrap();
+
+        // The arguments must render as a JSON object, not as a quoted, escaped string
+        assert_eq!(
+            out,
+            r#"{"format":"celsius","location":"San Francisco, CA"}"#
+        );
+    }
+
+    #[test]
+    fn test_normalize_tool_arguments_items_loop() {
+        let tmpl = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;
+
+        let mut messages = serde_json::Value::Array(vec![serde_json::json!({
+            "role": "assistant",
+            "tool_calls": [{
+                "type": "function",
+                "function": {
+                    "name": "f",
+                    "arguments": "{\"a\":1,\"b\":\"x\"}"
+                }
+            }]
+        })]);
+
+        normalize_tool_arguments_in_messages(&mut messages);
+
+        let mut env = Environment::new();
+        env.add_template("t", tmpl).unwrap();
+        let out = env
+            .get_template("t")
+            .unwrap()
+            .render(context! { messages => messages.as_array().unwrap() })
+            .unwrap();
+
+        assert!(out == "a=1;b=x;" || out == "b=x;a=1;"); // either key order is accepted
+    }
+
+    #[test]
+    fn test_normalize_tool_arguments_legacy_function_call() {
+        // The deprecated `function_call` field (OpenAI compat) gets the same normalization
+        let mut messages = serde_json::Value::Array(vec![serde_json::json!({
+            "role": "assistant",
+            "function_call": {
+                "name": "get_weather",
+                "arguments": "{\"location\":\"NYC\"}"
+            }
+        })]);
+
+        normalize_tool_arguments_in_messages(&mut messages);
+
+        assert_eq!(
+            messages[0]["function_call"]["arguments"],
+            serde_json::json!({"location": "NYC"})
+        );
+    }
+
+    #[test]
+    fn test_normalize_tool_arguments_malformed_json_passthrough() {
+        // Arguments that are not valid JSON must pass through unchanged, still as a string
+        let mut messages = serde_json::Value::Array(vec![serde_json::json!({
+            "role": "assistant",
+            "tool_calls": [{
+                "type": "function",
+                "function": {
+                    "name": "f",
+                    "arguments": "not valid json at all"
+                }
+            }]
+        })]);
+
+        normalize_tool_arguments_in_messages(&mut messages);
+
+        assert_eq!(
+            messages[0]["tool_calls"][0]["function"]["arguments"],
+            serde_json::Value::String("not valid json at all".to_string())
+        );
+    }
+
+    #[test]
+    fn test_normalize_tool_arguments_with_multimodal_content() {
+        let json_str = r#"{
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Check this:"},
+                        {"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
+                        {"type": "text", "text": "Interesting?"}
+                    ]
+                },
+                {
+                    "role": "assistant",
+                    "tool_calls": [{
+                        "id": "call_123",
+                        "type": "function",
+                        "function": {
+                            "name": "analyze_video",
+                            "arguments": "{\"url\":\"https://example.com/vid.mp4\",\"format\":\"mp4\"}"
+                        }
+                    }]
+                }
+            ]
+        }"#;
+
+        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+        let mut messages = serde_json::to_value(request.messages()).unwrap();
+
+        normalize_tool_arguments_in_messages(&mut messages);
+
+        // The user message's multimodal `content` array is left intact (still 3 parts)
+        assert!(messages[0]["content"].is_array());
+        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
+
+        // The assistant's tool-call arguments are now a JSON object with readable fields
+        assert!(messages[1]["tool_calls"][0]["function"]["arguments"].is_object());
+        assert_eq!(
+            messages[1]["tool_calls"][0]["function"]["arguments"]["url"],
+            "https://example.com/vid.mp4"
+        );
+    }
+
     fn user() -> Msg {
         Msg::User(Default::default())
     }