Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ See `examples/` for more runnable examples.

## Known Issues

- ..
- **TODO**: Evaluate whether `ToolCallingConfig` should be required rather than optional. Currently providers default to a guard with sensible limits when no config is set, but allowing `None` suggests users can opt out of loop protection entirely. We may want to be opinionated here and always require a `ToolCallingConfig`.

## License

Expand Down
12 changes: 12 additions & 0 deletions macros/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/// Convert a snake_case name to PascalCase.
///
/// Each `_`-separated segment has its first character upper-cased and the
/// rest appended unchanged; empty segments (leading/trailing/double
/// underscores) contribute nothing.
pub fn to_pascal_case(name: &str) -> String {
    // Preallocate: the result is at most as long as the input in the
    // common ASCII case (uppercasing can expand, but rarely).
    let mut result = String::with_capacity(name.len());
    for segment in name.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield more than one char (e.g. 'ß' -> "SS"),
            // so extend rather than push a single char.
            result.extend(first.to_uppercase());
            result.push_str(chars.as_str());
        }
    }
    result
}
1 change: 1 addition & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
use proc_macro::TokenStream;
use quote::quote;

mod common;
mod tool;
mod tools;

Expand Down
12 changes: 1 addition & 11 deletions macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,7 @@ pub fn tool_impl(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
// Generate the wrapper struct name
let wrapper_name = quote::format_ident!(
"{}Tool",
fn_name
.to_string()
.split('_')
.map(|s| {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
})
.collect::<String>()
crate::common::to_pascal_case(&fn_name.to_string())
);

// Check if function is async
Expand Down
20 changes: 6 additions & 14 deletions macros/src/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,6 @@ impl Parse for ToolsList {
}
}

/// Convert a snake_case function name to PascalCase struct name
fn to_pascal_case(name: &str) -> String {
    // Fold the underscore-separated segments into one accumulator,
    // capitalizing the first character of each non-empty segment.
    name.split('_').fold(String::new(), |mut acc, segment| {
        if let Some(first) = segment.chars().next() {
            // Unicode uppercasing can produce multiple chars, hence `extend`.
            acc.extend(first.to_uppercase());
            // Slice past the first char's byte length; this is always a
            // valid char boundary.
            acc.push_str(&segment[first.len_utf8()..]);
        }
        acc
    })
}

pub fn tools_impl(input: TokenStream) -> Result<TokenStream> {
let tools_list = syn::parse2::<ToolsList>(input)?;

Expand All @@ -73,7 +60,12 @@ pub fn tools_impl(input: TokenStream) -> Result<TokenStream> {
let wrapper_names: Vec<_> = tools_list
.tools
.iter()
.map(|tool_name| quote::format_ident!("{}Tool", to_pascal_case(&tool_name.to_string())))
.map(|tool_name| {
quote::format_ident!(
"{}Tool",
crate::common::to_pascal_case(&tool_name.to_string())
)
})
.collect();

// Generate different code based on whether context is present
Expand Down
4 changes: 2 additions & 2 deletions src/completions/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl<P: CompletionProviderConfig> CompletionClient<P> {
result: result.clone(),
});

// If not parallel, process one at a time
// In sequential mode process one call per model turn.
if !is_parallel {
break;
}
Expand All @@ -225,7 +225,7 @@ impl<P: CompletionProviderConfig> CompletionClient<P> {
}

/// Convert core messages to conversation items.
fn convert_messages_to_conversation(
pub(crate) fn convert_messages_to_conversation(
messages: &[crate::core::ConversationMessage],
) -> Result<Vec<ConversationItem>, LlmError> {
messages
Expand Down
88 changes: 24 additions & 64 deletions src/core/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,84 +402,44 @@ impl<State: private::Completable, Ctx: Send + Sync + 'static> LlmBuilder<State,
None
};

let conversation_messages: Vec<ConversationMessage> = messages
.into_iter()
.map(ConversationMessage::Chat)
.collect();

let req = StructuredRequest {
model: model_string,
messages: conversation_messages,
tool_config: tool_schemas.map(|tools| ToolConfig {
tools: Some(tools),
tool_choice: self.fields.tool_choice.clone(),
parallel_tool_calls: self.fields.parallel_tool_calls,
}),
generation_config: Some(GenerationConfig {
max_tokens: self.fields.max_tokens,
temperature: self.fields.temperature,
top_p: self.fields.top_p,
}),
};

let tool_registry = self.fields.tool_registry.as_ref();
match provider {
Provider::OpenAI => {
let conversation_messages: Vec<ConversationMessage> = messages
.into_iter()
.map(ConversationMessage::Chat)
.collect();

let req = StructuredRequest {
model: model_string,
messages: conversation_messages,
tool_config: tool_schemas.map(|tools| ToolConfig {
tools: Some(tools),
tool_choice: self.fields.tool_choice.clone(),
parallel_tool_calls: self.fields.parallel_tool_calls,
}),
generation_config: Some(GenerationConfig {
max_tokens: self.fields.max_tokens,
temperature: self.fields.temperature,
top_p: self.fields.top_p,
}),
};
let client = openai::create_openai_client_from_builder(&self)?;
client
.generate_completion::<T, Ctx>(
req,
format.clone(),
self.fields.tool_registry.as_ref(),
)
.generate_completion::<T, Ctx>(req, format, tool_registry)
.await
}
Provider::OpenRouter => {
let conversation_messages: Vec<ConversationMessage> = messages
.into_iter()
.map(ConversationMessage::Chat)
.collect();

let req = StructuredRequest {
model: model_string,
messages: conversation_messages,
tool_config: tool_schemas.map(|tools| ToolConfig {
tools: Some(tools),
tool_choice: self.fields.tool_choice.clone(),
parallel_tool_calls: self.fields.parallel_tool_calls,
}),
generation_config: Some(GenerationConfig {
max_tokens: self.fields.max_tokens,
temperature: self.fields.temperature,
top_p: self.fields.top_p,
}),
};
let client = openrouter::create_openrouter_client_from_builder(&self)?;
client
.generate_completion::<T, Ctx>(req, format, self.fields.tool_registry.as_ref())
.generate_completion::<T, Ctx>(req, format, tool_registry)
.await
}
Provider::Gemini => {
let conversation_messages: Vec<ConversationMessage> = messages
.into_iter()
.map(ConversationMessage::Chat)
.collect();

let req = StructuredRequest {
model: model_string,
messages: conversation_messages,
tool_config: tool_schemas.map(|tools| ToolConfig {
tools: Some(tools),
tool_choice: self.fields.tool_choice.clone(),
parallel_tool_calls: self.fields.parallel_tool_calls,
}),
generation_config: Some(GenerationConfig {
max_tokens: self.fields.max_tokens,
temperature: self.fields.temperature,
top_p: self.fields.top_p,
}),
};
let client = gemini::create_gemini_client_from_builder(&self)?;
client
.generate_completion::<T, Ctx>(req, format, self.fields.tool_registry.as_ref())
.generate_completion::<T, Ctx>(req, format, tool_registry)
.await
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ pub enum LlmError {
#[error("Tool call processing timeout exceeded: {timeout:?}")]
ToolCallTimeout { timeout: std::time::Duration },

#[error("Toll registration failed for {tool_name}: {message}")]
#[error("Tool registration failed for {tool_name}: {message}")]
ToolRegistration { tool_name: String, message: String },
}
45 changes: 23 additions & 22 deletions src/core/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,22 @@ impl HttpClient {
Req: Serialize,
Res: DeserializeOwned,
{
// Serialize request to Value for inspection
let body_value = serde_json::to_value(body).map_err(|e| LlmError::Parse {
message: "Failed to serialize request for inspection".to_string(),
source: Box::new(e),
})?;

// Call request inspector
// Only serialize to Value if we need to inspect the request
if let Some(ref config) = self.inspector_config
&& let Some(ref inspector) = config.request_inspector
{
let body_value = serde_json::to_value(body).map_err(|e| LlmError::Parse {
message: "Failed to serialize request for inspection".to_string(),
source: Box::new(e),
})?;
inspector(&body_value);
}

let mut last_error: Option<LlmError> = None;

for attempt in 0..=self.config.max_retries {
// Build request (must be rebuilt each attempt since .send() consumes it)
let mut req_builder = self.client.post(url).json(&body_value);
let mut req_builder = self.client.post(url).json(body);

// Add headers
for (name, value) in headers {
Expand All @@ -125,31 +123,34 @@ impl HttpClient {
if status.is_success() {
debug!(status = %status, "HTTP request successful");

// Parse response to text first, then to Value for inspection
let response_text = res.text().await.map_err(|e| LlmError::Parse {
message: "Failed to read response body".to_string(),
source: Box::new(e),
})?;

let response_value: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| LlmError::Parse {
message: "Failed to parse response as JSON".to_string(),
source: Box::new(e),
})?;

// Call response inspector
// Only go through intermediate Value if we need to inspect
if let Some(ref config) = self.inspector_config
&& let Some(ref inspector) = config.response_inspector
{
let response_value: serde_json::Value =
serde_json::from_str(&response_text).map_err(|e| {
LlmError::Parse {
message: "Failed to parse response as JSON".to_string(),
source: Box::new(e),
}
})?;
inspector(&response_value);
return serde_json::from_value(response_value).map_err(|e| {
LlmError::Parse {
message: "Failed to parse API response".to_string(),
source: Box::new(e),
}
});
}

// Deserialize to target type
return serde_json::from_value(response_value).map_err(|e| {
LlmError::Parse {
message: "Failed to parse API response".to_string(),
source: Box::new(e),
}
return serde_json::from_str(&response_text).map_err(|e| LlmError::Parse {
message: "Failed to parse API response".to_string(),
source: Box::new(e),
});
}

Expand Down
6 changes: 1 addition & 5 deletions src/core/tool_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ impl ToolCallingGuard {

/// Create a new ToolCallingGuard from a config
pub fn from_config(config: &ToolCallingConfig) -> Self {
Self {
max_iterations: config.max_iterations,
timeout: config.timeout,
current_iteration: 0,
}
Self::with_limits(config.max_iterations, config.timeout)
}

/// Increment iteration count and check if limit is exceeded
Expand Down
Loading