|
4 | 4 | use super::*; |
5 | 5 |
|
6 | 6 | use minijinja::{context, value::Value}; |
| 7 | +use std::result::Result::Ok; |
7 | 8 |
|
8 | 9 | use crate::protocols::openai::{ |
9 | 10 | chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest, |
@@ -121,6 +122,36 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value { |
121 | 122 | Value::from_serialize(&updated_messages) |
122 | 123 | } |
123 | 124 |
|
| 125 | +fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) { |
| 126 | + // Deserialize tool call arguments from JSON strings to objects/arrays before template rendering |
| 127 | + // avoids double encoding and enables iteration |
| 128 | + let Some(msgs) = messages.as_array_mut() else { |
| 129 | + return; |
| 130 | + }; |
| 131 | + |
| 132 | + for msg in msgs.iter_mut() { |
| 133 | + if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) { |
| 134 | + for tc in tool_calls { |
| 135 | + if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut()) |
| 136 | + && let Some(args) = function.get_mut("arguments") |
| 137 | + && let Some(s) = args.as_str() |
| 138 | + && let Ok(parsed) = serde_json::from_str(s) |
| 139 | + { |
| 140 | + *args = parsed; |
| 141 | + } |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut()) |
| 146 | + && let Some(args) = function_call.get_mut("arguments") |
| 147 | + && let Some(s) = args.as_str() |
| 148 | + && let Ok(parsed) = serde_json::from_str(s) |
| 149 | + { |
| 150 | + *args = parsed; |
| 151 | + } |
| 152 | + } |
| 153 | +} |
| 154 | + |
124 | 155 | impl OAIChatLikeRequest for NvCreateChatCompletionRequest { |
125 | 156 | fn model(&self) -> String { |
126 | 157 | self.inner.model.clone() |
@@ -267,8 +298,13 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { |
267 | 298 | add_generation_prompt |
268 | 299 | ); |
269 | 300 |
|
| 301 | + let messages_canonical = req.messages(); |
| 302 | + let mut messages_for_template: serde_json::Value = |
| 303 | + serde_json::to_value(&messages_canonical).unwrap(); |
| 304 | + normalize_tool_arguments_in_messages(&mut messages_for_template); |
| 305 | + |
270 | 306 | let ctx = context! { |
271 | | - messages => req.messages(), |
| 307 | + messages => messages_for_template, |
272 | 308 | tools => tools, |
273 | 309 | bos_token => self.config.bos_tok(), |
274 | 310 | eos_token => self.config.eos_tok(), |
@@ -298,6 +334,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { |
298 | 334 | mod tests { |
299 | 335 | use super::*; |
300 | 336 | use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg; |
| 337 | + use minijinja::{Environment, context}; |
301 | 338 |
|
302 | 339 | #[test] |
303 | 340 | fn test_may_be_fix_tool_schema_missing_type_and_properties() { |
@@ -705,6 +742,153 @@ NORMAL MODE |
705 | 742 | assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3); |
706 | 743 | } |
707 | 744 |
|
#[test]
fn test_normalize_tool_arguments_tojson() {
    // Template that serializes the first tool call's arguments through the `tojson` filter.
    let template_src = r#"{{ messages[0].tool_calls[0].function.arguments | tojson }}"#;

    // Assistant message whose tool call carries its arguments as a JSON-encoded string.
    let assistant = serde_json::json!({
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": "{\"format\":\"celsius\",\"location\":\"San Francisco, CA\"}"
            }
        }]
    });
    let mut messages = serde_json::Value::Array(vec![assistant]);
    normalize_tool_arguments_in_messages(&mut messages);

    let mut env = Environment::new();
    env.add_filter("tojson", super::super::tokcfg::tojson);
    env.add_template("t", template_src).unwrap();
    let template = env.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => messages.as_array().unwrap() })
        .unwrap();

    // The arguments must render as plain JSON rather than a doubly encoded string.
    assert_eq!(
        rendered,
        r#"{"format":"celsius","location":"San Francisco, CA"}"#
    );
}
| 778 | + |
#[test]
fn test_normalize_tool_arguments_items_loop() {
    // Template that walks the arguments as key/value pairs via the `items` filter.
    let template_src = r#"{% for k, v in messages[0].tool_calls[0].function.arguments|items %}{{k}}={{v}};{% endfor %}"#;

    let assistant = serde_json::json!({
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": "{\"a\":1,\"b\":\"x\"}"
            }
        }]
    });
    let mut messages = serde_json::Value::Array(vec![assistant]);
    normalize_tool_arguments_in_messages(&mut messages);

    let mut env = Environment::new();
    env.add_template("t", template_src).unwrap();
    let template = env.get_template("t").unwrap();
    let rendered = template
        .render(context! { messages => messages.as_array().unwrap() })
        .unwrap();

    // Either key order is accepted.
    assert!(matches!(rendered.as_str(), "a=1;b=x;" | "b=x;a=1;"));
}
| 806 | + |
#[test]
fn test_normalize_tool_arguments_legacy_function_call() {
    // Covers the deprecated `function_call` message field (OpenAI compat).
    let mut messages = serde_json::json!([{
        "role": "assistant",
        "function_call": {
            "name": "get_weather",
            "arguments": "{\"location\":\"NYC\"}"
        }
    }]);

    normalize_tool_arguments_in_messages(&mut messages);

    let expected = serde_json::json!({"location": "NYC"});
    let actual = &messages[0]["function_call"]["arguments"];
    assert_eq!(actual, &expected);
}
| 825 | + |
#[test]
fn test_normalize_tool_arguments_malformed_json_passthrough() {
    // Arguments that are not valid JSON must survive normalization unchanged, still a string.
    let mut messages = serde_json::json!([{
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "function": {
                "name": "f",
                "arguments": "not valid json at all"
            }
        }]
    }]);

    normalize_tool_arguments_in_messages(&mut messages);

    let arguments = &messages[0]["tool_calls"][0]["function"]["arguments"];
    assert_eq!(arguments.as_str(), Some("not valid json at all"));
}
| 847 | + |
#[test]
fn test_normalize_tool_arguments_with_multimodal_content() {
    // A request mixing a multi-part user message with an assistant tool call.
    let request_json = r#"{
        "model": "gpt-4o",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Check this:"},
                    {"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
                    {"type": "text", "text": "Interesting?"}
                ]
            },
            {
                "role": "assistant",
                "tool_calls": [{
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "analyze_video",
                        "arguments": "{\"url\":\"https://example.com/vid.mp4\",\"format\":\"mp4\"}"
                    }
                }]
            }
        ]
    }"#;

    let request: NvCreateChatCompletionRequest = serde_json::from_str(request_json).unwrap();
    let mut messages = serde_json::to_value(request.messages()).unwrap();
    normalize_tool_arguments_in_messages(&mut messages);

    // The multi-part user content must still be an array with all three parts.
    let content = &messages[0]["content"];
    assert!(content.is_array());
    assert_eq!(content.as_array().map(Vec::len), Some(3));

    // The tool-call arguments must now be an object with readable fields.
    let arguments = &messages[1]["tool_calls"][0]["function"]["arguments"];
    assert!(arguments.is_object());
    assert_eq!(arguments["url"], "https://example.com/vid.mp4");
}
| 891 | + |
708 | 892 | fn user() -> Msg { |
709 | 893 | Msg::User(Default::default()) |
710 | 894 | } |
|
0 commit comments