Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,16 @@ client<llm> InvalidCompositeFields {
options {
strategy [BaseClient]
http {
request_timeout_ms 5000 // Not allowed for composite
connect_timeout_ms 3000 // Not allowed for composite
invalid_field_name 5000 // Unrecognized field
}
}
}

// error: Unrecognized field 'request_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
// error: Unrecognized field 'invalid_field_name' in http configuration block. Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms
// --> client/http_config_composite_wrong_field.baml:12
// |
// 11 | strategy [BaseClient]
// 12 | http {
// 13 | request_timeout_ms 5000 // Not allowed for composite
// 14 | connect_timeout_ms 3000 // Not allowed for composite
// 15 | }
// |
// error: Unrecognized field 'connect_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
// --> client/http_config_composite_wrong_field.baml:12
// |
// 11 | strategy [BaseClient]
// 12 | http {
// 13 | request_timeout_ms 5000 // Not allowed for composite
// 14 | connect_timeout_ms 3000 // Not allowed for composite
// 15 | }
// 13 | invalid_field_name 5000 // Unrecognized field
// 14 | }
// |
44 changes: 30 additions & 14 deletions engine/baml-lib/llm-client/src/clients/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,18 @@ impl<Meta: Clone> PropertyHandler<Meta> {
// Define allowed fields based on provider type
let is_composite =
provider_type == "fallback" || provider_type == "round-robin";
// All timeouts are now allowed on both regular and composite clients
// Composite clients additionally support total_timeout_ms
let allowed_fields: HashSet<&str> = if is_composite {
// Composite clients only support total_timeout_ms
vec!["total_timeout_ms"].into_iter().collect()
vec![
"connect_timeout_ms",
"request_timeout_ms",
"time_to_first_token_timeout_ms",
"idle_timeout_ms",
"total_timeout_ms",
]
.into_iter()
.collect()
} else {
// Regular clients support all timeout types except total_timeout_ms
vec![
Expand All @@ -532,7 +541,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {

for (key, (_, value)) in config_map {
match key.as_str() {
"connect_timeout_ms" if !is_composite => {
"connect_timeout_ms" => {
let value_meta = value.meta().clone();
match value.into_numeric() {
Ok((val_str, _)) => {
Expand All @@ -556,7 +565,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
}
}
}
"request_timeout_ms" if !is_composite => {
"request_timeout_ms" => {
let value_meta = value.meta().clone();
match value.into_numeric() {
Ok((val_str, _)) => {
Expand All @@ -580,7 +589,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
}
}
}
"time_to_first_token_timeout_ms" if !is_composite => {
"time_to_first_token_timeout_ms" => {
let value_meta = value.meta().clone();
match value.into_numeric() {
Ok((val_str, _)) => {
Expand All @@ -603,7 +612,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
}
}
}
"idle_timeout_ms" if !is_composite => {
"idle_timeout_ms" => {
let value_meta = value.meta().clone();
match value.into_numeric() {
Ok((val_str, _)) => {
Expand Down Expand Up @@ -665,21 +674,28 @@ impl<Meta: Clone> PropertyHandler<Meta> {
// Build error messages with suggestions
for unrecognized_field in &unrecognized_fields {
let error_msg = if is_composite {
// For composite clients
if unrecognized_field == "total_timeout_ms" {
// This shouldn't happen as it's in the allowed list for composites
continue;
} else if let Some(suggestion) =
find_best_match(unrecognized_field, &["total_timeout_ms"])
// For composite clients - all timeouts are supported
let all_timeout_fields = vec![
"connect_timeout_ms",
"request_timeout_ms",
"time_to_first_token_timeout_ms",
"idle_timeout_ms",
"total_timeout_ms",
];

if let Some(suggestion) =
find_best_match(unrecognized_field, &all_timeout_fields)
{
format!(
"Unrecognized field '{unrecognized_field}' in http configuration block. Did you mean '{suggestion}'? \
Composite clients (fallback/round-robin) only support: total_timeout_ms"
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
)
} else {
format!(
"Unrecognized field '{unrecognized_field}' in http configuration block. \
Composite clients (fallback/round-robin) only support: total_timeout_ms"
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
)
}
} else {
Expand Down
91 changes: 91 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ pub async fn orchestrate(
let mut results = Vec::new();
let mut total_sleep_duration = std::time::Duration::from_secs(0);

// Extract total_timeout_ms from strategy if present
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
node.scope.scope.iter().find_map(|scope| match scope {
super::ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
super::ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
_ => None,
})
});

// Track the start time for total timeout
let start_time = web_time::Instant::now();

// Create a future that either waits for cancellation or never completes
let cancel_future = match cancel_tripwire {
Some(tripwire) => Box::pin(async move {
Expand All @@ -73,8 +85,59 @@ pub async fn orchestrate(
tokio::pin!(cancel_future);

for node in iter {
// Check for total timeout before starting each client
if let Some(timeout_ms) = total_timeout_ms {
let elapsed = start_time.elapsed();
if elapsed.as_millis() >= timeout_ms as u128 {
let cancel_scope = node.scope.clone();
results.push((
cancel_scope,
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
client: node.provider.name().to_string(),
model: None,
message: format!("Total timeout of {}ms exceeded", timeout_ms),
code: crate::internal::llm_client::ErrorCode::Timeout,
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
start_time: web_time::SystemTime::now(),
latency: elapsed,
request_options: Default::default(),
}),
Some(Err(anyhow::anyhow!(
crate::errors::ExposedError::TimeoutError {
client_name: node.provider.name().to_string(),
message: format!(
"Total timeout of {}ms exceeded (elapsed: {}ms)",
timeout_ms,
elapsed.as_millis()
),
}
))),
));
break;
}
}

// Check for cancellation at the start of each iteration
let cancel_scope = node.scope.clone();

// Clone data needed for timeout error before moving node
let client_name_for_timeout = node.provider.name().to_string();

// Create a timeout future if total_timeout_ms is set
let timeout_future = if let Some(timeout_ms) = total_timeout_ms {
let remaining_time = timeout_ms.saturating_sub(start_time.elapsed().as_millis() as u64);
if remaining_time == 0 {
// Budget ran out after the pre-check above passed; the zero-length sleep lets the timeout branch of the select below report it
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(0)))
as std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
} else {
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(remaining_time)))
}
} else {
Box::pin(futures::future::pending())
};
tokio::pin!(timeout_future);

tokio::select! {
biased;

Expand All @@ -90,6 +153,34 @@ pub async fn orchestrate(
));
break;
}
_ = &mut timeout_future => {
// Total timeout exceeded during client execution
let elapsed = start_time.elapsed();
results.push((
cancel_scope,
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
client: client_name_for_timeout.clone(),
model: None,
message: format!("Total timeout of {}ms exceeded", total_timeout_ms.unwrap()),
code: crate::internal::llm_client::ErrorCode::Timeout,
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
start_time: web_time::SystemTime::now(),
latency: elapsed,
request_options: Default::default(),
}),
Some(Err(anyhow::anyhow!(
crate::errors::ExposedError::TimeoutError {
client_name: client_name_for_timeout,
message: format!(
"Total timeout of {}ms exceeded (elapsed: {}ms)",
total_timeout_ms.unwrap(),
elapsed.as_millis()
),
}
))),
));
break;
}
result = async {
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
Ok(p) => p,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl std::fmt::Display for ExecutionScope {
write!(f, "RoundRobin({}, {})", strategy.name, index)
}
ExecutionScope::Fallback(strategy, index) => {
write!(f, "Fallback({strategy}, {index})")
write!(f, "Fallback({}, {})", strategy.name, index)
}
}
}
Expand Down Expand Up @@ -152,8 +152,8 @@ pub enum ExecutionScope {
Retry(String, usize, Duration),
// StrategyName, ClientIndex
RoundRobin(Arc<RoundRobinStrategy>, usize),
// StrategyName, ClientIndex
Fallback(String, usize),
// Strategy (with http_config), ClientIndex
Fallback(Arc<super::strategy::fallback::FallbackStrategy>, usize),
}

pub type OrchestratorNodeIterator = Vec<OrchestratorNode>;
Expand Down
91 changes: 91 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ where
let mut results = Vec::new();
let mut total_sleep_duration = web_time::Duration::from_secs(0);

// Extract total_timeout_ms from strategy if present
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
node.scope.scope.iter().find_map(|scope| match scope {
ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
_ => None,
})
});

// Track the start time for total timeout
let start_time = web_time::Instant::now();

// Create a future that either waits for cancellation or never completes
let cancel_future = match cancel_tripwire {
Some(tripwire) => Box::pin(async move {
Expand All @@ -184,8 +196,59 @@ where

//advanced curl viewing, use render_raw_curl on each node. TODO
for node in iter {
// Check for total timeout before starting each client
if let Some(timeout_ms) = total_timeout_ms {
let elapsed = start_time.elapsed();
if elapsed.as_millis() >= timeout_ms as u128 {
let cancel_scope = node.scope.clone();
results.push((
cancel_scope,
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
client: node.provider.name().to_string(),
model: None,
message: format!("Total timeout of {}ms exceeded", timeout_ms),
code: crate::internal::llm_client::ErrorCode::Timeout,
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
start_time: web_time::SystemTime::now(),
latency: elapsed,
request_options: Default::default(),
}),
Some(Err(anyhow::anyhow!(
crate::errors::ExposedError::TimeoutError {
client_name: node.provider.name().to_string(),
message: format!(
"Total timeout of {}ms exceeded (elapsed: {}ms)",
timeout_ms,
elapsed.as_millis()
),
}
))),
));
break;
}
}

// Check for cancellation at the start of each iteration
let cancel_scope = node.scope.clone();

// Clone data needed for timeout error before moving node
let client_name_for_timeout = node.provider.name().to_string();

// Create a timeout future if total_timeout_ms is set
let timeout_future = if let Some(timeout_ms) = total_timeout_ms {
let remaining_time = timeout_ms.saturating_sub(start_time.elapsed().as_millis() as u64);
if remaining_time == 0 {
// Budget ran out after the pre-check above passed; the zero-length sleep lets the timeout branch of the select below report it
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(0)))
as std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
} else {
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(remaining_time)))
}
} else {
Box::pin(futures::future::pending())
};
tokio::pin!(timeout_future);

tokio::select! {
biased;

Expand All @@ -201,6 +264,34 @@ where
));
break;
}
_ = &mut timeout_future => {
// Total timeout exceeded during client execution
let elapsed = start_time.elapsed();
results.push((
cancel_scope,
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
client: client_name_for_timeout.clone(),
model: None,
message: format!("Total timeout of {}ms exceeded", total_timeout_ms.unwrap()),
code: crate::internal::llm_client::ErrorCode::Timeout,
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
start_time: web_time::SystemTime::now(),
latency: elapsed,
request_options: Default::default(),
}),
Some(Err(anyhow::anyhow!(
crate::errors::ExposedError::TimeoutError {
client_name: client_name_for_timeout,
message: format!(
"Total timeout of {}ms exceeded (elapsed: {}ms)",
total_timeout_ms.unwrap(),
elapsed.as_millis()
),
}
))),
));
break;
}
result = async {
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
Ok(p) => p,
Expand Down
Loading
Loading