|
4 | 4 | use super::*; |
5 | 5 |
|
6 | 6 | use minijinja::{context, value::Value}; |
| 7 | +use std::result::Result::Ok; |
7 | 8 |
|
8 | 9 | use crate::protocols::openai::{ |
9 | 10 | chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest, |
@@ -131,26 +132,22 @@ fn normalize_tool_arguments_in_messages(messages: &mut serde_json::Value) { |
131 | 132 | for msg in msgs.iter_mut() { |
132 | 133 | if let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) { |
133 | 134 | for tc in tool_calls { |
134 | | - if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut()) { |
135 | | - if let Some(args) = function.get_mut("arguments") { |
136 | | - if let Some(s) = args.as_str() { |
137 | | - if let Result::Ok(parsed) = serde_json::from_str(s) { |
138 | | - *args = parsed; |
139 | | - } |
140 | | - } |
141 | | - } |
| 135 | + if let Some(function) = tc.get_mut("function").and_then(|v| v.as_object_mut()) |
| 136 | + && let Some(args) = function.get_mut("arguments") |
| 137 | + && let Some(s) = args.as_str() |
| 138 | + && let Ok(parsed) = serde_json::from_str(s) |
| 139 | + { |
| 140 | + *args = parsed; |
142 | 141 | } |
143 | 142 | } |
144 | 143 | } |
145 | 144 |
|
146 | | - if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut()) { |
147 | | - if let Some(args) = function_call.get_mut("arguments") { |
148 | | - if let Some(s) = args.as_str() { |
149 | | - if let Result::Ok(parsed) = serde_json::from_str(s) { |
150 | | - *args = parsed; |
151 | | - } |
152 | | - } |
153 | | - } |
| 145 | + if let Some(function_call) = msg.get_mut("function_call").and_then(|v| v.as_object_mut()) |
| 146 | + && let Some(args) = function_call.get_mut("arguments") |
| 147 | + && let Some(s) = args.as_str() |
| 148 | + && let Ok(parsed) = serde_json::from_str(s) |
| 149 | + { |
| 150 | + *args = parsed; |
154 | 151 | } |
155 | 152 | } |
156 | 153 | } |
@@ -800,7 +797,6 @@ NORMAL MODE |
800 | 797 | .render(context! { messages => messages.as_array().unwrap() }) |
801 | 798 | .unwrap(); |
802 | 799 |
|
803 | | - // Order-insensitive check: either a=1;b=x; or b=x;a=1; |
804 | 800 | assert!(out == "a=1;b=x;" || out == "b=x;a=1;"); |
805 | 801 | } |
806 | 802 |
|
@@ -839,7 +835,6 @@ NORMAL MODE |
839 | 835 |
|
840 | 836 | normalize_tool_arguments_in_messages(&mut messages); |
841 | 837 |
|
842 | | - // Should remain as string |
843 | 838 | assert_eq!( |
844 | 839 | messages[0]["tool_calls"][0]["function"]["arguments"], |
845 | 840 | serde_json::Value::String("not valid json at all".to_string()) |
|
0 commit comments