diff --git a/README.md b/README.md index f47132d..125fbfa 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ See `examples/` for more runnable examples. ## Known Issues -- .. +- **TODO**: Evaluate whether `ToolCallingConfig` should be required rather than optional. Currently providers default to a guard with sensible limits when no config is set, but allowing `None` suggests users can opt out of loop protection entirely. We may want to be opinionated here and always require a `ToolCallingConfig`. ## License diff --git a/macros/src/common.rs b/macros/src/common.rs new file mode 100644 index 0000000..61b42e2 --- /dev/null +++ b/macros/src/common.rs @@ -0,0 +1,12 @@ +/// Convert a snake_case name to PascalCase. +pub fn to_pascal_case(name: &str) -> String { + name.split('_') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect() +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 03d711a..4a92cc1 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -37,6 +37,7 @@ use proc_macro::TokenStream; use quote::quote; +mod common; mod tool; mod tools; diff --git a/macros/src/tool.rs b/macros/src/tool.rs index da7c96e..cd5b8c3 100644 --- a/macros/src/tool.rs +++ b/macros/src/tool.rs @@ -31,17 +31,7 @@ pub fn tool_impl(attr: TokenStream, item: TokenStream) -> Result { // Generate the wrapper struct name let wrapper_name = quote::format_ident!( "{}Tool", - fn_name - .to_string() - .split('_') - .map(|s| { - let mut c = s.chars(); - match c.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + c.as_str(), - } - }) - .collect::() + crate::common::to_pascal_case(&fn_name.to_string()) ); // Check if function is async diff --git a/macros/src/tools.rs b/macros/src/tools.rs index 0e6e141..94bcb28 100644 --- a/macros/src/tools.rs +++ b/macros/src/tools.rs @@ -46,19 +46,6 @@ impl Parse for ToolsList { } } -/// Convert a snake_case function name to PascalCase struct name -fn to_pascal_case(name: &str) -> String { - name.split('_') - .map(|s| { - let mut c = s.chars(); - match c.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + c.as_str(), - } - }) - .collect() -} - pub fn tools_impl(input: TokenStream) -> Result { let tools_list = syn::parse2::(input)?; @@ -73,7 +60,12 @@ pub fn tools_impl(input: TokenStream) -> Result { let wrapper_names: Vec<_> = tools_list .tools .iter() - .map(|tool_name| quote::format_ident!("{}Tool", to_pascal_case(&tool_name.to_string()))) + .map(|tool_name| { + quote::format_ident!( + "{}Tool", + crate::common::to_pascal_case(&tool_name.to_string()) + ) + }) .collect(); // Generate different code based on whether context is present diff --git a/src/completions/client.rs b/src/completions/client.rs index 239df0c..b0754e7 100644 --- a/src/completions/client.rs +++ b/src/completions/client.rs @@ -211,7 +211,7 @@ impl CompletionClient

{ result: result.clone(), }); - // If not parallel, process one at a time + // In sequential mode process one call per model turn. if !is_parallel { break; } @@ -225,7 +225,7 @@ impl CompletionClient

{ } /// Convert core messages to conversation items. -fn convert_messages_to_conversation( +pub(crate) fn convert_messages_to_conversation( messages: &[crate::core::ConversationMessage], ) -> Result, LlmError> { messages diff --git a/src/core/builder.rs b/src/core/builder.rs index 7eb5d74..ec85e54 100644 --- a/src/core/builder.rs +++ b/src/core/builder.rs @@ -402,84 +402,44 @@ impl LlmBuilder = messages + .into_iter() + .map(ConversationMessage::Chat) + .collect(); + + let req = StructuredRequest { + model: model_string, + messages: conversation_messages, + tool_config: tool_schemas.map(|tools| ToolConfig { + tools: Some(tools), + tool_choice: self.fields.tool_choice.clone(), + parallel_tool_calls: self.fields.parallel_tool_calls, + }), + generation_config: Some(GenerationConfig { + max_tokens: self.fields.max_tokens, + temperature: self.fields.temperature, + top_p: self.fields.top_p, + }), + }; + + let tool_registry = self.fields.tool_registry.as_ref(); match provider { Provider::OpenAI => { - let conversation_messages: Vec = messages - .into_iter() - .map(ConversationMessage::Chat) - .collect(); - - let req = StructuredRequest { - model: model_string, - messages: conversation_messages, - tool_config: tool_schemas.map(|tools| ToolConfig { - tools: Some(tools), - tool_choice: self.fields.tool_choice.clone(), - parallel_tool_calls: self.fields.parallel_tool_calls, - }), - generation_config: Some(GenerationConfig { - max_tokens: self.fields.max_tokens, - temperature: self.fields.temperature, - top_p: self.fields.top_p, - }), - }; let client = openai::create_openai_client_from_builder(&self)?; client - .generate_completion::( - req, - format.clone(), - self.fields.tool_registry.as_ref(), - ) + .generate_completion::(req, format, tool_registry) .await } Provider::OpenRouter => { - let conversation_messages: Vec = messages - .into_iter() - .map(ConversationMessage::Chat) - .collect(); - - let req = StructuredRequest { - model: model_string, - messages: conversation_messages, - tool_config: tool_schemas.map(|tools| ToolConfig { - tools: Some(tools), - tool_choice: self.fields.tool_choice.clone(), - parallel_tool_calls: self.fields.parallel_tool_calls, - }), - generation_config: Some(GenerationConfig { - max_tokens: self.fields.max_tokens, - temperature: self.fields.temperature, - top_p: self.fields.top_p, - }), - }; let client = openrouter::create_openrouter_client_from_builder(&self)?; client - .generate_completion::(req, format, self.fields.tool_registry.as_ref()) + .generate_completion::(req, format, tool_registry) .await } Provider::Gemini => { - let conversation_messages: Vec = messages - .into_iter() - .map(ConversationMessage::Chat) - .collect(); - - let req = StructuredRequest { - model: model_string, - messages: conversation_messages, - tool_config: tool_schemas.map(|tools| ToolConfig { - tools: Some(tools), - tool_choice: self.fields.tool_choice.clone(), - parallel_tool_calls: self.fields.parallel_tool_calls, - }), - generation_config: Some(GenerationConfig { - max_tokens: self.fields.max_tokens, - temperature: self.fields.temperature, - top_p: self.fields.top_p, - }), - }; let client = gemini::create_gemini_client_from_builder(&self)?; client - .generate_completion::(req, format, self.fields.tool_registry.as_ref()) + .generate_completion::(req, format, tool_registry) .await } } diff --git a/src/core/error.rs b/src/core/error.rs index 654f610..ac9a77f 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -56,6 +56,6 @@ pub enum LlmError { #[error("Tool call processing timeout exceeded: {timeout:?}")] ToolCallTimeout { timeout: std::time::Duration }, - #[error("Toll registration failed for {tool_name}: {message}")] + #[error("Tool registration failed for {tool_name}: {message}")] ToolRegistration { tool_name: String, message: String }, } diff --git a/src/core/http.rs b/src/core/http.rs index d8d043a..ef6958c 100644 --- a/src/core/http.rs +++ b/src/core/http.rs @@ -82,16 +82,14 @@ impl HttpClient { Req: Serialize, Res: DeserializeOwned, { - // Serialize request to Value for inspection - let body_value = serde_json::to_value(body).map_err(|e| LlmError::Parse { - message: "Failed to serialize request for inspection".to_string(), - source: Box::new(e), - })?; - - // Call request inspector + // Only serialize to Value if we need to inspect the request if let Some(ref config) = self.inspector_config && let Some(ref inspector) = config.request_inspector { + let body_value = serde_json::to_value(body).map_err(|e| LlmError::Parse { + message: "Failed to serialize request for inspection".to_string(), + source: Box::new(e), + })?; inspector(&body_value); } @@ -99,7 +97,7 @@ impl HttpClient { for attempt in 0..=self.config.max_retries { // Build request (must be rebuilt each attempt since .send() consumes it) - let mut req_builder = self.client.post(url).json(&body_value); + let mut req_builder = self.client.post(url).json(body); // Add headers for (name, value) in headers { @@ -125,31 +123,34 @@ impl HttpClient { if status.is_success() { debug!(status = %status, "HTTP request successful"); - // Parse response to text first, then to Value for inspection let response_text = res.text().await.map_err(|e| LlmError::Parse { message: "Failed to read response body".to_string(), source: Box::new(e), })?; - let response_value: serde_json::Value = - serde_json::from_str(&response_text).map_err(|e| LlmError::Parse { - message: "Failed to parse response as JSON".to_string(), - source: Box::new(e), - })?; - - // Call response inspector + // Only go through intermediate Value if we need to inspect if let Some(ref config) = self.inspector_config && let Some(ref inspector) = config.response_inspector { + let response_value: serde_json::Value = + serde_json::from_str(&response_text).map_err(|e| { + LlmError::Parse { + message: "Failed to parse response as JSON".to_string(), + source: Box::new(e), + } + })?; inspector(&response_value); + return serde_json::from_value(response_value).map_err(|e| { + LlmError::Parse { + message: "Failed to parse API response".to_string(), + source: Box::new(e), + } + }); } - // Deserialize to target type - return serde_json::from_value(response_value).map_err(|e| { - LlmError::Parse { - message: "Failed to parse API response".to_string(), - source: Box::new(e), - } + return serde_json::from_str(&response_text).map_err(|e| LlmError::Parse { + message: "Failed to parse API response".to_string(), + source: Box::new(e), }); } diff --git a/src/core/tool_guard.rs b/src/core/tool_guard.rs index 4b0fe47..d969356 100644 --- a/src/core/tool_guard.rs +++ b/src/core/tool_guard.rs @@ -61,11 +61,7 @@ impl ToolCallingGuard { /// Create a new ToolCallingGuard from a config pub fn from_config(config: &ToolCallingConfig) -> Self { - Self { - max_iterations: config.max_iterations, - timeout: config.timeout, - current_iteration: 0, - } + Self::with_limits(config.max_iterations, config.timeout) } /// Increment iteration count and check if limit is exceeded diff --git a/src/provider/gemini.rs b/src/provider/gemini.rs index e9302a7..422b499 100644 --- a/src/provider/gemini.rs +++ b/src/provider/gemini.rs @@ -342,16 +342,7 @@ impl CompletionRequestBuilder for GeminiRequestBuilder { let candidate = response.candidates.as_ref()?.first()?; let content = candidate.content.as_ref()?; - let mut calls = Vec::new(); - for (idx, part) in content.parts.iter().enumerate() { - if let Part::FunctionCall(FunctionCallPart { function_call }) = part { - calls.push(FunctionCallData { - id: format!("call_{}", idx), - name: function_call.name.clone(), - arguments: function_call.args.clone(), - }); - } - } + let calls: Vec = extract_function_calls_from_parts(&content.parts); if calls.is_empty() { None } else { Some(calls) } } @@ -565,28 +556,42 @@ fn build_tools_config( ) } -fn parse_parts_to_content(parts: &[Part]) -> Result { - let mut text_parts = Vec::new(); - let mut function_calls = Vec::new(); - - for (idx, part) in parts.iter().enumerate() { - match part { - Part::Text(TextPart { text }) => text_parts.push(text.clone()), - Part::FunctionCall(FunctionCallPart { function_call }) => { - function_calls.push(FunctionCallData { - id: format!("call_{}", idx), +/// Extract function calls from Gemini response parts with unique IDs. +fn extract_function_calls_from_parts(parts: &[Part]) -> Vec { + parts + .iter() + .enumerate() + .filter_map(|(idx, part)| { + if let Part::FunctionCall(FunctionCallPart { function_call }) = part { + Some(FunctionCallData { + id: format!("call_{}_{:08x}", idx, rand::random::()), name: function_call.name.clone(), arguments: function_call.args.clone(), - }); + }) + } else { + None } - Part::FunctionResponse(_) => {} - } - } + }) + .collect() +} +fn parse_parts_to_content(parts: &[Part]) -> Result { + let function_calls = extract_function_calls_from_parts(parts); if !function_calls.is_empty() { - Ok(ResponseContent::FunctionCalls(function_calls)) - } else if !text_parts.is_empty() { - Ok(ResponseContent::Text(text_parts.join(""))) + return Ok(ResponseContent::FunctionCalls(function_calls)); + } + + let text: String = parts + .iter() + .filter_map(|p| match p { + Part::Text(TextPart { text }) => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(""); + + if !text.is_empty() { + Ok(ResponseContent::Text(text)) } else { Err(LlmError::Provider { message: "Empty response from Gemini".to_string(), @@ -601,34 +606,24 @@ fn parse_parts_to_content(parts: &[Part]) -> Result { pub struct GeminiClient { completion_client: CompletionClient, - config: GeminiConfig, } impl GeminiClient { pub fn new(api_key: String) -> Result { - let config = GeminiConfig::new(api_key.clone()); - let completion_client = CompletionClient::new(GeminiConfig::new(api_key))?; - + let config = GeminiConfig::new(api_key); Ok(Self { - completion_client, - config, + completion_client: CompletionClient::new(config)?, }) } pub fn with_base_url(mut self, base_url: String) -> Result { + let config = &self.completion_client.config; let new_config = GeminiConfig { - api_key: self.config.api_key.clone(), - base_url: base_url.clone(), - tool_calling_config: self.config.tool_calling_config.clone(), - http_config: self.config.http_config.clone(), - inspector_config: self.config.inspector_config.clone(), - }; - self.config = GeminiConfig { - api_key: self.config.api_key.clone(), + api_key: config.api_key.clone(), base_url, - tool_calling_config: self.config.tool_calling_config.clone(), - http_config: self.config.http_config.clone(), - inspector_config: self.config.inspector_config.clone(), + tool_calling_config: config.tool_calling_config.clone(), + http_config: config.http_config.clone(), + inspector_config: config.inspector_config.clone(), }; self.completion_client = CompletionClient::new(new_config)?; Ok(self) @@ -638,27 +633,27 @@ impl GeminiClient { mut self, tool_config: ToolCallingConfig, ) -> Result { + let config = &self.completion_client.config; let new_config = GeminiConfig { - api_key: self.config.api_key.clone(), - base_url: self.config.base_url.clone(), - tool_calling_config: Some(tool_config.clone()), - http_config: self.config.http_config.clone(), - inspector_config: self.config.inspector_config.clone(), + api_key: config.api_key.clone(), + base_url: config.base_url.clone(), + tool_calling_config: Some(tool_config), + http_config: config.http_config.clone(), + inspector_config: config.inspector_config.clone(), }; - self.config.tool_calling_config = Some(tool_config); self.completion_client = CompletionClient::new(new_config)?; Ok(self) } pub fn with_http_config(mut self, http_config: HttpClientConfig) -> Result { + let config = &self.completion_client.config; let new_config = GeminiConfig { - api_key: self.config.api_key.clone(), - base_url: self.config.base_url.clone(), - tool_calling_config: self.config.tool_calling_config.clone(), - http_config: http_config.clone(), - inspector_config: self.config.inspector_config.clone(), + api_key: config.api_key.clone(), + base_url: config.base_url.clone(), + tool_calling_config: config.tool_calling_config.clone(), + http_config, + inspector_config: config.inspector_config.clone(), }; - self.config.http_config = http_config; self.completion_client = CompletionClient::new(new_config)?; Ok(self) } @@ -667,14 +662,14 @@ impl GeminiClient { mut self, inspector_config: InspectorConfig, ) -> Result { + let config = &self.completion_client.config; let new_config = GeminiConfig { - api_key: self.config.api_key.clone(), - base_url: self.config.base_url.clone(), - tool_calling_config: self.config.tool_calling_config.clone(), - http_config: self.config.http_config.clone(), - inspector_config: Some(inspector_config.clone()), + api_key: config.api_key.clone(), + base_url: config.base_url.clone(), + tool_calling_config: config.tool_calling_config.clone(), + http_config: config.http_config.clone(), + inspector_config: Some(inspector_config), }; - self.config.inspector_config = Some(inspector_config); self.completion_client = CompletionClient::new(new_config)?; Ok(self) } @@ -701,14 +696,14 @@ impl LlmProvider for GeminiClient { .and_then(|tc| tc.tools.as_ref()) .is_some(); - if has_tools && tool_registry.is_some() { - let mut guard = self.config.get_tool_calling_guard(); + if has_tools && let Some(tool_registry) = tool_registry { + let mut guard = self.completion_client.config.get_tool_calling_guard(); let provider_response = self .completion_client .handle_tool_calling_loop::<_, Ctx>( &builder, request, - tool_registry.unwrap(), + tool_registry, &mut guard, format, ) @@ -717,7 +712,8 @@ impl LlmProvider for GeminiClient { } // Single request without tool calling loop - let conversation = convert_messages_to_conversation(&request.messages)?; + let conversation = + crate::completions::client::convert_messages_to_conversation(&request.messages)?; let api_request = builder.build_request(&request, &format, &conversation)?; let api_response = self .completion_client @@ -728,38 +724,6 @@ impl LlmProvider for GeminiClient { } } -fn convert_messages_to_conversation( - messages: &[crate::core::ConversationMessage], -) -> Result, LlmError> { - messages - .iter() - .map(|msg| match msg { - crate::core::ConversationMessage::Chat(m) => { - let role = match m.role { - crate::core::ChatRole::System => "system", - crate::core::ChatRole::User => "user", - crate::core::ChatRole::Assistant => "assistant", - }; - Ok(ConversationItem::Message { - role: role.to_string(), - content: m.content.clone(), - }) - } - crate::core::ConversationMessage::ToolCall(tc) => Ok(ConversationItem::FunctionCall { - id: tc.call_id.clone(), - name: tc.name.clone(), - arguments: tc.arguments.clone(), - }), - crate::core::ConversationMessage::ToolCallResult(tr) => { - Ok(ConversationItem::FunctionResult { - call_id: tr.tool_call_id.clone(), - result: tr.content.clone(), - }) - } - }) - .collect() -} - // ============================================================================ // Builder Integration // ============================================================================ diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 87fc38c..7eb2e42 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -57,18 +57,17 @@ impl OpenAiConfig { self } - pub fn get_tool_calling_guard(&self) -> ToolCallingGuard { - if let Some(ref config) = self.tool_calling_config { - ToolCallingGuard::with_limits(config.max_iterations, config.timeout) - } else { - ToolCallingGuard::new() - } - } - pub fn with_http_config(mut self, config: HttpClientConfig) -> Self { self.http_config = config; self } + + pub fn get_tool_calling_guard(&self) -> ToolCallingGuard { + match self.tool_calling_config { + Some(ref config) => ToolCallingGuard::from_config(config), + None => ToolCallingGuard::default(), + } + } } impl ResponsesProviderConfig for OpenAiConfig { @@ -176,35 +175,10 @@ impl LlmProvider for OpenAiClient { T: crate::CompletionTarget + Send, Ctx: Send + Sync + 'static, { - // If tools are present and we have a registry, handle automatic tool calling - let has_tools = request - .tool_config - .as_ref() - .and_then(|tc| tc.tools.as_ref()) - .is_some(); - - if has_tools && let Some(tool_registry) = tool_registry { - let mut guard = self.responses_client.config.get_tool_calling_guard(); - return self - .responses_client - .handle_tool_calling_loop::(request, tool_registry, &mut guard, format) - .await; - } - - // Otherwise, make a single request expecting the configured completion output - let messages_clone = request.messages.clone(); - let responses_request = self.responses_client.build_request_with_format( - &request, - &crate::responses::convert_messages_to_responses_format(messages_clone)?, - format, - )?; - let api_response = self - .responses_client - .make_api_request(responses_request) - .await?; - let provider_response = - crate::responses::convert_to_provider_response(api_response, super::Provider::OpenAI)?; - T::parse_response(provider_response) + let guard = self.responses_client.config.get_tool_calling_guard(); + self.responses_client + .generate_completion::(request, format, tool_registry, guard) + .await } } diff --git a/src/provider/openrouter.rs b/src/provider/openrouter.rs index 839ea5e..e869dea 100644 --- a/src/provider/openrouter.rs +++ b/src/provider/openrouter.rs @@ -227,37 +227,10 @@ impl LlmProvider for OpenRouterClient { T: crate::CompletionTarget + Send, Ctx: Send + Sync + 'static, { - // If tools are present and we have a registry, handle automatic tool calling - let has_tools = request - .tool_config - .as_ref() - .and_then(|tc| tc.tools.as_ref()) - .is_some(); - - if has_tools && let Some(tool_registry) = tool_registry { - let mut guard = self.responses_client.config.get_tool_calling_guard(); - return self - .responses_client - .handle_tool_calling_loop::(request, tool_registry, &mut guard, format) - .await; - } - - // Otherwise, make a single request expecting the configured completion output - let messages_clone = request.messages.clone(); - let responses_request = self.responses_client.build_request_with_format( - &request, - &crate::responses::convert_messages_to_responses_format(messages_clone)?, - format, - )?; - let api_response = self - .responses_client - .make_api_request(responses_request) - .await?; - let provider_response = crate::responses::convert_to_provider_response( - api_response, - super::Provider::OpenRouter, - )?; - T::parse_response(provider_response) + let guard = self.responses_client.config.get_tool_calling_guard(); + self.responses_client + .generate_completion::(request, format, tool_registry, guard) + .await } } diff --git a/src/responses/client.rs b/src/responses/client.rs index 92c79c5..e3d5a77 100644 --- a/src/responses/client.rs +++ b/src/responses/client.rs @@ -112,7 +112,6 @@ impl ResponsesClient

{ { let timeout_duration = guard.timeout; - // Use tokio::time::timeout to add timeout protection match tokio::time::timeout( timeout_duration, self.handle_tool_calling_loop_internal::(request, tool_registry, guard, format), @@ -156,7 +155,6 @@ impl ResponsesClient

{ .unwrap_or(true); loop { - // Check iteration limit before processing guard.increment_iteration()?; let iteration_span = @@ -236,7 +234,9 @@ impl ResponsesClient

{ } } - /// Process function calls in parallel (all calls first, then all results) + /// Process function calls in parallel (all calls added first, tools executed concurrently, then all results added). + /// + /// On failure, completed tools may have produced side-effects but no results are added to `responses_input`. pub async fn process_parallel_function_calls( &self, function_calls: &[&FunctionToolCall], @@ -246,33 +246,31 @@ impl ResponsesClient

{ where Ctx: Send + Sync + 'static, { - let mut pending_executions = Vec::new(); - - // Add all function calls to input and prepare for execution + // Add all function calls to input first + let mut tool_calls = Vec::with_capacity(function_calls.len()); for function_call in function_calls { responses_input.push(InputItem::FunctionCall((*function_call).clone())); let arguments = self.parse_function_arguments(&function_call.arguments)?; - pending_executions.push(( - function_call.id.clone(), - function_call.call_id.clone(), - function_call.name.clone(), + tool_calls.push(ToolCall { + id: function_call.id.clone(), + call_id: function_call.call_id.clone(), + name: function_call.name.clone(), arguments, - )); + }); } - // Execute all tools and add their results - for (id, call_id, name, arguments) in pending_executions { - let tool_call = ToolCall { - id, - call_id: call_id.clone(), - name, - arguments, - }; - let result = tool_registry.execute(&tool_call).await?; + // Execute all tools concurrently + let futures: Vec<_> = tool_calls + .iter() + .map(|tc| tool_registry.execute(tc)) + .collect(); + let results = futures::future::try_join_all(futures).await?; + // Add all results in order + for (tool_call, result) in tool_calls.iter().zip(results) { responses_input.push(InputItem::FunctionCallOutput(FunctionToolCallOutput { - call_id, + call_id: tool_call.call_id.clone(), output: result, r#type: "function_call_output".to_string(), })); @@ -314,6 +312,41 @@ impl ResponsesClient

{ Ok(()) } + /// Shared generate_completion for responses-API providers (OpenAI, OpenRouter). + /// + /// Handles the tool calling loop when tools are present, or makes a single + /// request otherwise. + pub async fn generate_completion( + &self, + request: StructuredRequest, + format: crate::responses::request::Format, + tool_registry: Option<&ToolRegistry>, + mut guard: ToolCallingGuard, + ) -> Result + where + T: CompletionTarget + Send, + Ctx: Send + Sync + 'static, + { + let has_tools = request + .tool_config + .as_ref() + .and_then(|tc| tc.tools.as_ref()) + .is_some(); + + if has_tools && let Some(tool_registry) = tool_registry { + return self + .handle_tool_calling_loop::(request, tool_registry, &mut guard, format) + .await; + } + + let responses_input = convert_messages_to_responses_format(request.messages.clone())?; + let responses_request = + self.build_request_with_format(&request, &responses_input, format)?; + let api_response = self.make_api_request(responses_request).await?; + let provider_response = convert_to_provider_response(api_response, self.config.provider())?; + T::parse_response(provider_response) + } + /// Parse function arguments from JSON value pub fn parse_function_arguments( &self, @@ -510,58 +543,52 @@ pub(crate) fn create_text_format() -> Format { } } -/// Convert OpenAI API response to provider-agnostic ProviderResponse +/// Convert OpenAI API response to provider-agnostic ProviderResponse. +/// +/// Aggregates all output items: collects function calls across all items, +/// concatenates text from all messages, and surfaces refusals. +/// Function calls take priority over text if both are present. pub fn convert_to_provider_response( res: Response, provider: crate::provider::Provider, ) -> Result { use crate::core::{FunctionCallData, LanguageModelUsage, ProviderResponse, ResponseContent}; - let output_content = res.output.first().ok_or_else(|| LlmError::Provider { - message: "No output in response".to_string(), - source: None, - })?; - - let content = match output_content { - OutputContent::OutputMessage(message) => { - let msg_content = message.content.first().ok_or_else(|| LlmError::Provider { - message: "No content in message".to_string(), - source: None, - })?; - - match msg_content { - MessageContent::OutputText(output) => ResponseContent::Text(output.text.clone()), - MessageContent::Refusal(refusal) => { - ResponseContent::Refusal(refusal.refusal.clone()) + let mut function_calls = Vec::new(); + let mut text_parts = Vec::new(); + let mut refusal = None; + + for output in &res.output { + match output { + OutputContent::OutputMessage(message) => { + for content in &message.content { + match content { + MessageContent::OutputText(text) => text_parts.push(text.text.clone()), + MessageContent::Refusal(r) => refusal = Some(r.refusal.clone()), + } } } - } - OutputContent::FunctionCall(fc) => { - // Collect all function calls from the output - let function_calls: Vec = res - .output - .iter() - .filter_map(|o| match o { - OutputContent::FunctionCall(fc) => Some(FunctionCallData { - id: fc.call_id.clone(), - name: fc.name.clone(), - arguments: fc.arguments.clone(), - }), - _ => None, - }) - .collect(); - - if function_calls.is_empty() { - // This shouldn't happen since we matched FunctionCall, but handle it - ResponseContent::FunctionCalls(vec![FunctionCallData { + OutputContent::FunctionCall(fc) => { + function_calls.push(FunctionCallData { id: fc.call_id.clone(), name: fc.name.clone(), arguments: fc.arguments.clone(), - }]) - } else { - ResponseContent::FunctionCalls(function_calls) + }); } } + } + + let content = if !function_calls.is_empty() { + ResponseContent::FunctionCalls(function_calls) + } else if let Some(refusal) = refusal { + ResponseContent::Refusal(refusal) + } else if !text_parts.is_empty() { + ResponseContent::Text(text_parts.join("")) + } else { + return Err(LlmError::Provider { + message: "No output in response".to_string(), + source: None, + }); }; Ok(ProviderResponse { diff --git a/src/responses/response.rs b/src/responses/response.rs index d065f04..fe23f3b 100644 --- a/src/responses/response.rs +++ b/src/responses/response.rs @@ -11,9 +11,11 @@ pub struct Response { } #[derive(Debug, Deserialize)] -#[serde(untagged)] +#[serde(tag = "type")] pub enum OutputContent { + #[serde(rename = "message")] OutputMessage(OutputMessage), + #[serde(rename = "function_call")] FunctionCall(FunctionToolCall), } @@ -24,22 +26,12 @@ pub struct Usage { pub total_tokens: i32, } -// TODO: Remove this, once text input is supported #[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct OutputMessage { pub id: String, - - #[allow(dead_code)] - /// This is always `message` - #[serde(rename = "type")] - pub r#type: String, - pub status: Status, - pub content: Vec, - - /// This is always `assistant` pub role: String, } @@ -52,31 +44,71 @@ pub enum Status { } #[derive(Debug, Deserialize)] -#[serde(untagged, rename_all = "snake_case")] +#[serde(tag = "type")] pub enum MessageContent { + #[serde(rename = "output_text")] OutputText(OutputText), + #[serde(rename = "refusal")] Refusal(Refusal), } #[derive(Debug, Deserialize)] pub struct OutputText { - #[allow(dead_code)] - /// Always `output_text` - #[serde(rename = "type")] - pub r#type: String, - pub text: String, - // TODO - // annotations } #[derive(Debug, Deserialize)] pub struct Refusal { - /// The refusal explanation from the model. pub refusal: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn tagged_function_call_deserializes() { + let payload = json!({ + "type": "function_call", + "id": "tool_1", + "call_id": "tool_1", + "name": "lookup_weather", + "arguments": "{\"city\":\"Lisbon\"}" + }); + + let parsed: OutputContent = + serde_json::from_value(payload).expect("function_call should deserialize"); + + match parsed { + OutputContent::FunctionCall(call) => { + assert_eq!(call.name, "lookup_weather"); + assert_eq!(call.call_id, "tool_1"); + } + _ => panic!("expected function_call output content"), + } + } + + /// The `type` field is consumed by the tagged enum during deserialization, + /// so `FunctionToolCall.r#type` is populated by its serde default. + /// Verify it still serializes correctly for use as an API input item. + #[test] + fn function_call_round_trips_with_type_field() { + let payload = json!({ + "type": "function_call", + "id": "tool_1", + "call_id": "tool_1", + "name": "lookup_weather", + "arguments": "{\"city\":\"Lisbon\"}" + }); + + let parsed: OutputContent = serde_json::from_value(payload).expect("should deserialize"); + + let OutputContent::FunctionCall(call) = parsed else { + panic!("expected function_call"); + }; - #[allow(dead_code)] - /// Always `refusal` - #[serde(rename = "type")] - pub r#type: String, + let serialized = serde_json::to_value(&call).expect("should serialize"); + assert_eq!(serialized["type"], "function_call"); + } } diff --git a/src/responses/types.rs b/src/responses/types.rs index eef3c9d..a4449fa 100644 --- a/src/responses/types.rs +++ b/src/responses/types.rs @@ -1,8 +1,15 @@ use serde::{Deserialize, Serialize}; +fn default_function_call_type() -> String { + "function_call".to_string() +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct FunctionToolCall { - #[serde(rename = "type")] + /// Required by the API when serialized as an input item. + /// The `default` is necessary because `OutputContent`'s `#[serde(tag = "type")]` + /// consumes this field during deserialization, so serde never sees it on the struct. + #[serde(rename = "type", default = "default_function_call_type")] pub r#type: String, pub id: String, pub call_id: String,