Skip to content

Commit 1ee7e5b

Browse files
committed
compound client timeouts
1 parent 38cfe9b commit 1ee7e5b

File tree

9 files changed

+207
-59
lines changed

9 files changed

+207
-59
lines changed

engine/baml-lib/baml/tests/validation_files/client/http_config_composite_wrong_field.baml

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,16 @@ client<llm> InvalidCompositeFields {
1010
options {
1111
strategy [BaseClient]
1212
http {
13-
request_timeout_ms 5000 // Not allowed for composite
14-
connect_timeout_ms 3000 // Not allowed for composite
13+
invalid_field_name 5000 // Unrecognized field
1514
}
1615
}
1716
}
1817

19-
// error: Unrecognized field 'request_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
18+
// error: Unrecognized field 'invalid_field_name' in http configuration block. Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms
2019
// --> client/http_config_composite_wrong_field.baml:12
2120
// |
2221
// 11 | strategy [BaseClient]
2322
// 12 | http {
24-
// 13 | request_timeout_ms 5000 // Not allowed for composite
25-
// 14 | connect_timeout_ms 3000 // Not allowed for composite
26-
// 15 | }
27-
// |
28-
// error: Unrecognized field 'connect_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
29-
// --> client/http_config_composite_wrong_field.baml:12
30-
// |
31-
// 11 | strategy [BaseClient]
32-
// 12 | http {
33-
// 13 | request_timeout_ms 5000 // Not allowed for composite
34-
// 14 | connect_timeout_ms 3000 // Not allowed for composite
35-
// 15 | }
23+
// 13 | invalid_field_name 5000 // Unrecognized field
24+
// 14 | }
3625
// |

engine/baml-lib/llm-client/src/clients/helpers.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,18 @@ impl<Meta: Clone> PropertyHandler<Meta> {
515515
// Define allowed fields based on provider type
516516
let is_composite =
517517
provider_type == "fallback" || provider_type == "round-robin";
518+
// Per-request timeouts (connect, request, time-to-first-token, idle) are allowed on both regular and composite clients
519+
// Composite clients additionally support total_timeout_ms
518520
let allowed_fields: HashSet<&str> = if is_composite {
519-
// Composite clients only support total_timeout_ms
520-
vec!["total_timeout_ms"].into_iter().collect()
521+
vec![
522+
"connect_timeout_ms",
523+
"request_timeout_ms",
524+
"time_to_first_token_timeout_ms",
525+
"idle_timeout_ms",
526+
"total_timeout_ms",
527+
]
528+
.into_iter()
529+
.collect()
521530
} else {
522531
// Regular clients support all timeout types except total_timeout_ms
523532
vec![
@@ -532,7 +541,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
532541

533542
for (key, (_, value)) in config_map {
534543
match key.as_str() {
535-
"connect_timeout_ms" if !is_composite => {
544+
"connect_timeout_ms" => {
536545
let value_meta = value.meta().clone();
537546
match value.into_numeric() {
538547
Ok((val_str, _)) => {
@@ -556,7 +565,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
556565
}
557566
}
558567
}
559-
"request_timeout_ms" if !is_composite => {
568+
"request_timeout_ms" => {
560569
let value_meta = value.meta().clone();
561570
match value.into_numeric() {
562571
Ok((val_str, _)) => {
@@ -580,7 +589,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
580589
}
581590
}
582591
}
583-
"time_to_first_token_timeout_ms" if !is_composite => {
592+
"time_to_first_token_timeout_ms" => {
584593
let value_meta = value.meta().clone();
585594
match value.into_numeric() {
586595
Ok((val_str, _)) => {
@@ -603,7 +612,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
603612
}
604613
}
605614
}
606-
"idle_timeout_ms" if !is_composite => {
615+
"idle_timeout_ms" => {
607616
let value_meta = value.meta().clone();
608617
match value.into_numeric() {
609618
Ok((val_str, _)) => {
@@ -665,21 +674,28 @@ impl<Meta: Clone> PropertyHandler<Meta> {
665674
// Build error messages with suggestions
666675
for unrecognized_field in &unrecognized_fields {
667676
let error_msg = if is_composite {
668-
// For composite clients
669-
if unrecognized_field == "total_timeout_ms" {
670-
// This shouldn't happen as it's in the allowed list for composites
671-
continue;
672-
} else if let Some(suggestion) =
673-
find_best_match(unrecognized_field, &["total_timeout_ms"])
677+
// For composite clients - all timeouts are supported
678+
let all_timeout_fields = vec![
679+
"connect_timeout_ms",
680+
"request_timeout_ms",
681+
"time_to_first_token_timeout_ms",
682+
"idle_timeout_ms",
683+
"total_timeout_ms",
684+
];
685+
686+
if let Some(suggestion) =
687+
find_best_match(unrecognized_field, &all_timeout_fields)
674688
{
675689
format!(
676690
"Unrecognized field '{unrecognized_field}' in http configuration block. Did you mean '{suggestion}'? \
677-
Composite clients (fallback/round-robin) only support: total_timeout_ms"
691+
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
692+
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
678693
)
679694
} else {
680695
format!(
681696
"Unrecognized field '{unrecognized_field}' in http configuration block. \
682-
Composite clients (fallback/round-robin) only support: total_timeout_ms"
697+
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
698+
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
683699
)
684700
}
685701
} else {

engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ pub async fn orchestrate(
6262
let mut results = Vec::new();
6363
let mut total_sleep_duration = std::time::Duration::from_secs(0);
6464

65+
// Extract total_timeout_ms from strategy if present
66+
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
67+
node.scope.scope.iter().find_map(|scope| match scope {
68+
super::ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
69+
super::ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
70+
_ => None,
71+
})
72+
});
73+
74+
// Track the start time for total timeout
75+
let start_time = web_time::Instant::now();
76+
6577
// Create a future that either waits for cancellation or never completes
6678
let cancel_future = match cancel_tripwire {
6779
Some(tripwire) => Box::pin(async move {
@@ -73,6 +85,38 @@ pub async fn orchestrate(
7385
tokio::pin!(cancel_future);
7486

7587
for node in iter {
88+
// Check for total timeout before starting each client
89+
if let Some(timeout_ms) = total_timeout_ms {
90+
let elapsed = start_time.elapsed();
91+
if elapsed.as_millis() >= timeout_ms as u128 {
92+
let cancel_scope = node.scope.clone();
93+
results.push((
94+
cancel_scope,
95+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
96+
client: node.provider.name().to_string(),
97+
model: None,
98+
message: format!("Total timeout of {}ms exceeded", timeout_ms),
99+
code: crate::internal::llm_client::ErrorCode::Timeout,
100+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
101+
start_time: web_time::SystemTime::now(),
102+
latency: elapsed,
103+
request_options: Default::default(),
104+
}),
105+
Some(Err(anyhow::anyhow!(
106+
crate::errors::ExposedError::TimeoutError {
107+
client_name: node.provider.name().to_string(),
108+
message: format!(
109+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
110+
timeout_ms,
111+
elapsed.as_millis()
112+
),
113+
}
114+
))),
115+
));
116+
break;
117+
}
118+
}
119+
76120
// Check for cancellation at the start of each iteration
77121
let cancel_scope = node.scope.clone();
78122
tokio::select! {

engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl std::fmt::Display for ExecutionScope {
5353
write!(f, "RoundRobin({}, {})", strategy.name, index)
5454
}
5555
ExecutionScope::Fallback(strategy, index) => {
56-
write!(f, "Fallback({strategy}, {index})")
56+
write!(f, "Fallback({}, {})", strategy.name, index)
5757
}
5858
}
5959
}
@@ -152,8 +152,8 @@ pub enum ExecutionScope {
152152
Retry(String, usize, Duration),
153153
// Strategy (with http_config), ClientIndex
154154
RoundRobin(Arc<RoundRobinStrategy>, usize),
155-
// StrategyName, ClientIndex
156-
Fallback(String, usize),
155+
// Strategy (with http_config), ClientIndex
156+
Fallback(Arc<super::strategy::fallback::FallbackStrategy>, usize),
157157
}
158158

159159
pub type OrchestratorNodeIterator = Vec<OrchestratorNode>;

engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ where
172172
let mut results = Vec::new();
173173
let mut total_sleep_duration = web_time::Duration::from_secs(0);
174174

175+
// Extract total_timeout_ms from strategy if present
176+
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
177+
node.scope.scope.iter().find_map(|scope| match scope {
178+
ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
179+
ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
180+
_ => None,
181+
})
182+
});
183+
184+
// Track the start time for total timeout
185+
let start_time = web_time::Instant::now();
186+
175187
// Create a future that either waits for cancellation or never completes
176188
let cancel_future = match cancel_tripwire {
177189
Some(tripwire) => Box::pin(async move {
@@ -184,6 +196,38 @@ where
184196

185197
//advanced curl viewing, use render_raw_curl on each node. TODO
186198
for node in iter {
199+
// Check for total timeout before starting each client
200+
if let Some(timeout_ms) = total_timeout_ms {
201+
let elapsed = start_time.elapsed();
202+
if elapsed.as_millis() >= timeout_ms as u128 {
203+
let cancel_scope = node.scope.clone();
204+
results.push((
205+
cancel_scope,
206+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
207+
client: node.provider.name().to_string(),
208+
model: None,
209+
message: format!("Total timeout of {}ms exceeded", timeout_ms),
210+
code: crate::internal::llm_client::ErrorCode::Timeout,
211+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
212+
start_time: web_time::SystemTime::now(),
213+
latency: elapsed,
214+
request_options: Default::default(),
215+
}),
216+
Some(Err(anyhow::anyhow!(
217+
crate::errors::ExposedError::TimeoutError {
218+
client_name: node.provider.name().to_string(),
219+
message: format!(
220+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
221+
timeout_ms,
222+
elapsed.as_millis()
223+
),
224+
}
225+
))),
226+
));
227+
break;
228+
}
229+
}
230+
187231
// Check for cancellation at the start of each iteration
188232
let cancel_scope = node.scope.clone();
189233
tokio::select! {

engine/baml-runtime/src/internal/llm_client/strategy/fallback.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,28 @@ use crate::{
1515
RuntimeContext,
1616
};
1717

18+
#[derive(Debug)]
1819
pub struct FallbackStrategy {
1920
pub name: String,
2021
pub(super) retry_policy: Option<String>,
2122
// TODO: We can add conditions to each client
2223
client_specs: Vec<ClientSpec>,
24+
pub http_config: internal_llm_client::HttpConfig,
2325
}
2426

2527
fn resolve_strategy(
2628
provider: &ClientProvider,
2729
properties: &UnresolvedClientProperty<()>,
2830
ctx: &RuntimeContext,
29-
) -> Result<Vec<ClientSpec>> {
31+
) -> Result<(Vec<ClientSpec>, internal_llm_client::HttpConfig)> {
3032
let properties = properties.resolve(provider, &ctx.eval_ctx(false))?;
3133
let ResolvedClientProperty::Fallback(props) = properties else {
3234
anyhow::bail!(
3335
"Invalid client property. Should have been a fallback property but got: {}",
3436
properties.name()
3537
);
3638
};
37-
Ok(props.strategy)
39+
Ok((props.strategy, props.http_config))
3840
}
3941

4042
impl TryFrom<(&ClientProperty, &RuntimeContext)> for FallbackStrategy {
@@ -43,11 +45,13 @@ impl TryFrom<(&ClientProperty, &RuntimeContext)> for FallbackStrategy {
4345
fn try_from(
4446
(client, ctx): (&ClientProperty, &RuntimeContext),
4547
) -> std::result::Result<Self, Self::Error> {
46-
let strategy = resolve_strategy(&client.provider, &client.unresolved_options()?, ctx)?;
48+
let (strategy, http_config) =
49+
resolve_strategy(&client.provider, &client.unresolved_options()?, ctx)?;
4750
Ok(Self {
4851
name: client.name.clone(),
4952
retry_policy: client.retry_policy.clone(),
5053
client_specs: strategy,
54+
http_config,
5155
})
5256
}
5357
}
@@ -56,16 +60,18 @@ impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for FallbackStrategy {
5660
type Error = anyhow::Error;
5761

5862
fn try_from((client, ctx): (&ClientWalker, &RuntimeContext)) -> Result<Self> {
59-
let strategy = resolve_strategy(&client.elem().provider, client.options(), ctx)?;
63+
let (strategy, http_config) =
64+
resolve_strategy(&client.elem().provider, client.options(), ctx)?;
6065
Ok(Self {
6166
name: client.item.elem.name.clone(),
6267
retry_policy: client.retry_policy().as_ref().map(String::from),
6368
client_specs: strategy,
69+
http_config,
6470
})
6571
}
6672
}
6773

68-
impl IterOrchestrator for FallbackStrategy {
74+
impl IterOrchestrator for std::sync::Arc<FallbackStrategy> {
6975
fn iter_orchestrator<'a>(
7076
&self,
7177
state: &mut OrchestrationState,
@@ -83,7 +89,7 @@ impl IterOrchestrator for FallbackStrategy {
8389
let client = client.clone();
8490
Ok(client.iter_orchestrator(
8591
state,
86-
ExecutionScope::Fallback(self.name.clone(), idx).into(),
92+
ExecutionScope::Fallback(self.clone(), idx).into(),
8793
ctx,
8894
client_lookup,
8995
))

engine/baml-runtime/src/internal/llm_client/strategy/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::Arc;
22

33
use anyhow::Result;
4-
mod fallback;
4+
pub mod fallback;
55
pub mod roundrobin;
66

77
use internal_baml_core::ir::ClientWalker;
@@ -20,7 +20,7 @@ use crate::{
2020

2121
pub enum LLMStrategyProvider {
2222
RoundRobin(Arc<RoundRobinStrategy>),
23-
Fallback(FallbackStrategy),
23+
Fallback(Arc<FallbackStrategy>),
2424
}
2525

2626
impl std::fmt::Display for LLMStrategyProvider {
@@ -45,9 +45,9 @@ impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for LLMStrategyProvider {
4545
StrategyClientProvider::RoundRobin => RoundRobinStrategy::try_from((client, ctx))
4646
.map(Arc::new)
4747
.map(LLMStrategyProvider::RoundRobin),
48-
StrategyClientProvider::Fallback => {
49-
FallbackStrategy::try_from((client, ctx)).map(LLMStrategyProvider::Fallback)
50-
}
48+
StrategyClientProvider::Fallback => FallbackStrategy::try_from((client, ctx))
49+
.map(Arc::new)
50+
.map(LLMStrategyProvider::Fallback),
5151
},
5252
_ => {
5353
anyhow::bail!("Unsupported strategy provider: {}", client.elem().provider,)
@@ -65,9 +65,9 @@ impl TryFrom<(&ClientProperty, &RuntimeContext)> for LLMStrategyProvider {
6565
StrategyClientProvider::RoundRobin => RoundRobinStrategy::try_from((client, ctx))
6666
.map(Arc::new)
6767
.map(LLMStrategyProvider::RoundRobin),
68-
StrategyClientProvider::Fallback => {
69-
FallbackStrategy::try_from((client, ctx)).map(LLMStrategyProvider::Fallback)
70-
}
68+
StrategyClientProvider::Fallback => FallbackStrategy::try_from((client, ctx))
69+
.map(Arc::new)
70+
.map(LLMStrategyProvider::Fallback),
7171
},
7272
other => {
7373
let options = ["round-robin", "fallback"];

0 commit comments

Comments
 (0)