Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 185 additions & 1 deletion lib/llm/src/preprocessor/prompt/template/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use super::*;

use minijinja::{context, value::Value};
use std::result::Result::Ok;

use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
Expand Down Expand Up @@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
Value::from_serialize(&updated_messages)
}

fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) {
// Deserialize tool call arguments from JSON strings to objects/arrays before template rendering
// avoids double encoding and enables iteration
let Some(msgs) = messages.as_array_mut() else {
return;
};

for msg in msgs.iter_mut() {
if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) {
for tc in tool_calls {
if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut())
&& let Some(args) = function.get_mut("arguments")
&& let Some(s) = args.as_str()
&& let Ok(parsed) = serde_json::from_str(s)
{
*args = parsed;
}
}
}

if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut())
&& let Some(args) = function_call.get_mut("arguments")
&& let Some(s) = args.as_str()
&& let Ok(parsed) = serde_json::from_str(s)
{
*args = parsed;
}
}
}

impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn model(&self) -> String {
self.inner.model.clone()
Expand Down Expand Up @@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
add_generation_prompt
);

let messages_canonical = req.messages();
let mut messages_for_template: serde_json::Value =
serde_json::to_value(&messages_canonical).unwrap();
normalize_tool_arguments_in_messages(&mut messages_for_template);

let ctx = context! {
messages => req.messages(),
messages => messages_for_template,
tools => tools,
bos_token => self.config.bos_tok(),
eos_token => self.config.eos_tok(),
Expand Down Expand Up @@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
mod tests {
use super::*;
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
use minijinja::{Environment, context};

#[test]
fn test_may_be_fix_tool_schema_missing_type_and_properties() {
Expand Down Expand Up @@ -705,6 +742,153 @@ NORMAL MODE
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
}

/// A JSON-string `arguments` payload must render through `tojson` as a single layer of
/// JSON once the messages have been normalized.
#[test]
fn test_normalize_tool_arguments_tojson() {
    const TEMPLATE: &str = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;

    // Assistant message whose tool call carries its arguments as a JSON-encoded string.
    let mut messages = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
            }
        }]
    }]);
    normalize_tool_arguments_in_messages(&mut messages);

    let mut env = Environment::new();
    env.add_filter("tojson", super::super::tokcfg::tojson);
    env.add_template("t", TEMPLATE).unwrap();
    let template = env.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => messages.as_array().unwrap() })
        .unwrap();

    // The output must be the object itself, not a quoted (double-encoded) string.
    let expected = r#"{"format":"celsius","location":"San Francisco, CA"}"#;
    assert_eq!(rendered, expected);
}

/// After normalization a template must be able to iterate over the arguments with `items`.
#[test]
fn test_normalize_tool_arguments_items_loop() {
    const TEMPLATE: &str = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;

    let mut messages = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": "{\"a\":1,\"b\":\"x\"}"
            }
        }]
    }]);
    normalize_tool_arguments_in_messages(&mut messages);

    let mut env = Environment::new();
    env.add_template("t", TEMPLATE).unwrap();
    let template = env.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => messages.as_array().unwrap() })
        .unwrap();

    // Either key order is accepted.
    assert!(matches!(rendered.as_str(), "a=1;b=x;" | "b=x;a=1;"));
}

/// The deprecated `function_call` message shape (OpenAI compat) is normalized as well.
#[test]
fn test_normalize_tool_arguments_legacy_function_call() {
    let mut messages = serde_json::json!([{
        "role": "assistant",
        "function_call": {
            "name": "get_weather",
            "arguments": "{\"location\":\"NYC\"}"
        }
    }]);

    normalize_tool_arguments_in_messages(&mut messages);

    // The string payload must have been replaced by the parsed object.
    let expected = serde_json::json!({"location": "NYC"});
    let actual = &messages[0]["function_call"]["arguments"];
    assert_eq!(*actual, expected);
}

/// Arguments that are not valid JSON must survive normalization as the original string.
#[test]
fn test_normalize_tool_arguments_malformed_json_passthrough() {
    let mut messages = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": "not valid json at all"
            }
        }]
    }]);

    normalize_tool_arguments_in_messages(&mut messages);

    // `as_str` is `Some` only for a JSON string, so this checks both type and content.
    let arguments = &messages[0]["tool_calls"][0]["function"]["arguments"];
    assert_eq!(arguments.as_str(), Some("not valid json at all"));
}

/// Normalizing tool arguments must not disturb multimodal `content` arrays elsewhere in
/// the conversation.
#[test]
fn test_normalize_tool_arguments_with_multimodal_content() {
    let payload = r#"{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Check this:"},
{"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
{"type": "text", "text": "Interesting?"}
]
},
{
"role": "assistant",
"tool_calls": [{
"id": "call_123",
"type": "function",
"function": {
"name": "analyze_video",
"arguments": "{\"url\":\"https://example.com/vid.mp4\",\"format\":\"mp4\"}"
}
}]
}
]
}"#;

    let request: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let mut messages = serde_json::to_value(request.messages()).unwrap();

    normalize_tool_arguments_in_messages(&mut messages);

    // The user message keeps its three-part content array.
    let user_content = &messages[0]["content"];
    assert_eq!(user_content.as_array().map(Vec::len), Some(3));

    // The assistant's tool-call arguments are now a structured object.
    let call_arguments = &messages[1]["tool_calls"][0]["function"]["arguments"];
    assert!(call_arguments.is_object());
    assert_eq!(call_arguments["url"], "https://example.com/vid.mp4");
}

/// Test helper: builds a `Msg::User` wrapping the default value of its payload type.
fn user() -> Msg {
    let payload = Default::default();
    Msg::User(payload)
}
Expand Down
Loading