Skip to content

Commit 204fbe9

Browse files
committed
compound client timeouts
1 parent 38cfe9b commit 204fbe9

File tree

48 files changed

+4886
-66
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4886
-66
lines changed

engine/baml-lib/baml/tests/validation_files/client/http_config_composite_wrong_field.baml

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,16 @@ client<llm> InvalidCompositeFields {
1010
options {
1111
strategy [BaseClient]
1212
http {
13-
request_timeout_ms 5000 // Not allowed for composite
14-
connect_timeout_ms 3000 // Not allowed for composite
13+
invalid_field_name 5000 // Unrecognized field
1514
}
1615
}
1716
}
1817

19-
// error: Unrecognized field 'request_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
18+
// error: Unrecognized field 'invalid_field_name' in http configuration block. Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms
2019
// --> client/http_config_composite_wrong_field.baml:12
2120
// |
2221
// 11 | strategy [BaseClient]
2322
// 12 | http {
24-
// 13 | request_timeout_ms 5000 // Not allowed for composite
25-
// 14 | connect_timeout_ms 3000 // Not allowed for composite
26-
// 15 | }
27-
// |
28-
// error: Unrecognized field 'connect_timeout_ms' in http configuration block. Did you mean 'total_timeout_ms'? Composite clients (fallback/round-robin) only support: total_timeout_ms
29-
// --> client/http_config_composite_wrong_field.baml:12
30-
// |
31-
// 11 | strategy [BaseClient]
32-
// 12 | http {
33-
// 13 | request_timeout_ms 5000 // Not allowed for composite
34-
// 14 | connect_timeout_ms 3000 // Not allowed for composite
35-
// 15 | }
23+
// 13 | invalid_field_name 5000 // Unrecognized field
24+
// 14 | }
3625
// |

engine/baml-lib/llm-client/src/clients/helpers.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,18 @@ impl<Meta: Clone> PropertyHandler<Meta> {
515515
// Define allowed fields based on provider type
516516
let is_composite =
517517
provider_type == "fallback" || provider_type == "round-robin";
518+
// All timeouts are now allowed on both regular and composite clients
519+
// Composite clients additionally support total_timeout_ms
518520
let allowed_fields: HashSet<&str> = if is_composite {
519-
// Composite clients only support total_timeout_ms
520-
vec!["total_timeout_ms"].into_iter().collect()
521+
vec![
522+
"connect_timeout_ms",
523+
"request_timeout_ms",
524+
"time_to_first_token_timeout_ms",
525+
"idle_timeout_ms",
526+
"total_timeout_ms",
527+
]
528+
.into_iter()
529+
.collect()
521530
} else {
522531
// Regular clients support all timeout types except total_timeout_ms
523532
vec![
@@ -532,7 +541,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
532541

533542
for (key, (_, value)) in config_map {
534543
match key.as_str() {
535-
"connect_timeout_ms" if !is_composite => {
544+
"connect_timeout_ms" => {
536545
let value_meta = value.meta().clone();
537546
match value.into_numeric() {
538547
Ok((val_str, _)) => {
@@ -556,7 +565,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
556565
}
557566
}
558567
}
559-
"request_timeout_ms" if !is_composite => {
568+
"request_timeout_ms" => {
560569
let value_meta = value.meta().clone();
561570
match value.into_numeric() {
562571
Ok((val_str, _)) => {
@@ -580,7 +589,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
580589
}
581590
}
582591
}
583-
"time_to_first_token_timeout_ms" if !is_composite => {
592+
"time_to_first_token_timeout_ms" => {
584593
let value_meta = value.meta().clone();
585594
match value.into_numeric() {
586595
Ok((val_str, _)) => {
@@ -603,7 +612,7 @@ impl<Meta: Clone> PropertyHandler<Meta> {
603612
}
604613
}
605614
}
606-
"idle_timeout_ms" if !is_composite => {
615+
"idle_timeout_ms" => {
607616
let value_meta = value.meta().clone();
608617
match value.into_numeric() {
609618
Ok((val_str, _)) => {
@@ -665,21 +674,28 @@ impl<Meta: Clone> PropertyHandler<Meta> {
665674
// Build error messages with suggestions
666675
for unrecognized_field in &unrecognized_fields {
667676
let error_msg = if is_composite {
668-
// For composite clients
669-
if unrecognized_field == "total_timeout_ms" {
670-
// This shouldn't happen as it's in the allowed list for composites
671-
continue;
672-
} else if let Some(suggestion) =
673-
find_best_match(unrecognized_field, &["total_timeout_ms"])
677+
// For composite clients - all timeouts are supported
678+
let all_timeout_fields = vec![
679+
"connect_timeout_ms",
680+
"request_timeout_ms",
681+
"time_to_first_token_timeout_ms",
682+
"idle_timeout_ms",
683+
"total_timeout_ms",
684+
];
685+
686+
if let Some(suggestion) =
687+
find_best_match(unrecognized_field, &all_timeout_fields)
674688
{
675689
format!(
676690
"Unrecognized field '{unrecognized_field}' in http configuration block. Did you mean '{suggestion}'? \
677-
Composite clients (fallback/round-robin) only support: total_timeout_ms"
691+
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
692+
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
678693
)
679694
} else {
680695
format!(
681696
"Unrecognized field '{unrecognized_field}' in http configuration block. \
682-
Composite clients (fallback/round-robin) only support: total_timeout_ms"
697+
Composite clients (fallback/round-robin) support: connect_timeout_ms, request_timeout_ms, \
698+
time_to_first_token_timeout_ms, idle_timeout_ms, total_timeout_ms"
683699
)
684700
}
685701
} else {

engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ pub async fn orchestrate(
6262
let mut results = Vec::new();
6363
let mut total_sleep_duration = std::time::Duration::from_secs(0);
6464

65+
// Extract total_timeout_ms from strategy if present
66+
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
67+
node.scope.scope.iter().find_map(|scope| match scope {
68+
super::ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
69+
super::ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
70+
_ => None,
71+
})
72+
});
73+
74+
// Track the start time for total timeout
75+
let start_time = web_time::Instant::now();
76+
6577
// Create a future that either waits for cancellation or never completes
6678
let cancel_future = match cancel_tripwire {
6779
Some(tripwire) => Box::pin(async move {
@@ -73,8 +85,59 @@ pub async fn orchestrate(
7385
tokio::pin!(cancel_future);
7486

7587
for node in iter {
88+
// Check for total timeout before starting each client
89+
if let Some(timeout_ms) = total_timeout_ms {
90+
let elapsed = start_time.elapsed();
91+
if elapsed.as_millis() >= timeout_ms as u128 {
92+
let cancel_scope = node.scope.clone();
93+
results.push((
94+
cancel_scope,
95+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
96+
client: node.provider.name().to_string(),
97+
model: None,
98+
message: format!("Total timeout of {}ms exceeded", timeout_ms),
99+
code: crate::internal::llm_client::ErrorCode::Timeout,
100+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
101+
start_time: web_time::SystemTime::now(),
102+
latency: elapsed,
103+
request_options: Default::default(),
104+
}),
105+
Some(Err(anyhow::anyhow!(
106+
crate::errors::ExposedError::TimeoutError {
107+
client_name: node.provider.name().to_string(),
108+
message: format!(
109+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
110+
timeout_ms,
111+
elapsed.as_millis()
112+
),
113+
}
114+
))),
115+
));
116+
break;
117+
}
118+
}
119+
76120
// Check for cancellation at the start of each iteration
77121
let cancel_scope = node.scope.clone();
122+
123+
// Clone data needed for timeout error before moving node
124+
let client_name_for_timeout = node.provider.name().to_string();
125+
126+
// Create a timeout future if total_timeout_ms is set
127+
let timeout_future = if let Some(timeout_ms) = total_timeout_ms {
128+
let remaining_time = timeout_ms.saturating_sub(start_time.elapsed().as_millis() as u64);
129+
if remaining_time == 0 {
130+
// Already exceeded, will be caught by the check above
131+
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(0)))
132+
as std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
133+
} else {
134+
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(remaining_time)))
135+
}
136+
} else {
137+
Box::pin(futures::future::pending())
138+
};
139+
tokio::pin!(timeout_future);
140+
78141
tokio::select! {
79142
biased;
80143

@@ -90,6 +153,34 @@ pub async fn orchestrate(
90153
));
91154
break;
92155
}
156+
_ = &mut timeout_future => {
157+
// Total timeout exceeded during client execution
158+
let elapsed = start_time.elapsed();
159+
results.push((
160+
cancel_scope,
161+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
162+
client: client_name_for_timeout.clone(),
163+
model: None,
164+
message: format!("Total timeout of {}ms exceeded", total_timeout_ms.unwrap()),
165+
code: crate::internal::llm_client::ErrorCode::Timeout,
166+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
167+
start_time: web_time::SystemTime::now(),
168+
latency: elapsed,
169+
request_options: Default::default(),
170+
}),
171+
Some(Err(anyhow::anyhow!(
172+
crate::errors::ExposedError::TimeoutError {
173+
client_name: client_name_for_timeout,
174+
message: format!(
175+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
176+
total_timeout_ms.unwrap(),
177+
elapsed.as_millis()
178+
),
179+
}
180+
))),
181+
));
182+
break;
183+
}
93184
result = async {
94185
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
95186
Ok(p) => p,

engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl std::fmt::Display for ExecutionScope {
5353
write!(f, "RoundRobin({}, {})", strategy.name, index)
5454
}
5555
ExecutionScope::Fallback(strategy, index) => {
56-
write!(f, "Fallback({strategy}, {index})")
56+
write!(f, "Fallback({}, {})", strategy.name, index)
5757
}
5858
}
5959
}
@@ -152,8 +152,8 @@ pub enum ExecutionScope {
152152
Retry(String, usize, Duration),
153153
// StrategyName, ClientIndex
154154
RoundRobin(Arc<RoundRobinStrategy>, usize),
155-
// StrategyName, ClientIndex
156-
Fallback(String, usize),
155+
// Strategy (with http_config), ClientIndex
156+
Fallback(Arc<super::strategy::fallback::FallbackStrategy>, usize),
157157
}
158158

159159
pub type OrchestratorNodeIterator = Vec<OrchestratorNode>;

engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ where
172172
let mut results = Vec::new();
173173
let mut total_sleep_duration = web_time::Duration::from_secs(0);
174174

175+
// Extract total_timeout_ms from strategy if present
176+
let total_timeout_ms: Option<u64> = iter.first().and_then(|node| {
177+
node.scope.scope.iter().find_map(|scope| match scope {
178+
ExecutionScope::Fallback(strategy, _) => strategy.http_config.total_timeout_ms,
179+
ExecutionScope::RoundRobin(strategy, _) => strategy.http_config.total_timeout_ms,
180+
_ => None,
181+
})
182+
});
183+
184+
// Track the start time for total timeout
185+
let start_time = web_time::Instant::now();
186+
175187
// Create a future that either waits for cancellation or never completes
176188
let cancel_future = match cancel_tripwire {
177189
Some(tripwire) => Box::pin(async move {
@@ -184,8 +196,59 @@ where
184196

185197
//advanced curl viewing, use render_raw_curl on each node. TODO
186198
for node in iter {
199+
// Check for total timeout before starting each client
200+
if let Some(timeout_ms) = total_timeout_ms {
201+
let elapsed = start_time.elapsed();
202+
if elapsed.as_millis() >= timeout_ms as u128 {
203+
let cancel_scope = node.scope.clone();
204+
results.push((
205+
cancel_scope,
206+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
207+
client: node.provider.name().to_string(),
208+
model: None,
209+
message: format!("Total timeout of {}ms exceeded", timeout_ms),
210+
code: crate::internal::llm_client::ErrorCode::Timeout,
211+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
212+
start_time: web_time::SystemTime::now(),
213+
latency: elapsed,
214+
request_options: Default::default(),
215+
}),
216+
Some(Err(anyhow::anyhow!(
217+
crate::errors::ExposedError::TimeoutError {
218+
client_name: node.provider.name().to_string(),
219+
message: format!(
220+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
221+
timeout_ms,
222+
elapsed.as_millis()
223+
),
224+
}
225+
))),
226+
));
227+
break;
228+
}
229+
}
230+
187231
// Check for cancellation at the start of each iteration
188232
let cancel_scope = node.scope.clone();
233+
234+
// Clone data needed for timeout error before moving node
235+
let client_name_for_timeout = node.provider.name().to_string();
236+
237+
// Create a timeout future if total_timeout_ms is set
238+
let timeout_future = if let Some(timeout_ms) = total_timeout_ms {
239+
let remaining_time = timeout_ms.saturating_sub(start_time.elapsed().as_millis() as u64);
240+
if remaining_time == 0 {
241+
// Already exceeded, will be caught by the check above
242+
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(0)))
243+
as std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
244+
} else {
245+
Box::pin(async_std::task::sleep(std::time::Duration::from_millis(remaining_time)))
246+
}
247+
} else {
248+
Box::pin(futures::future::pending())
249+
};
250+
tokio::pin!(timeout_future);
251+
189252
tokio::select! {
190253
biased;
191254

@@ -201,6 +264,34 @@ where
201264
));
202265
break;
203266
}
267+
_ = &mut timeout_future => {
268+
// Total timeout exceeded during client execution
269+
let elapsed = start_time.elapsed();
270+
results.push((
271+
cancel_scope,
272+
LLMResponse::LLMFailure(crate::internal::llm_client::LLMErrorResponse {
273+
client: client_name_for_timeout.clone(),
274+
model: None,
275+
message: format!("Total timeout of {}ms exceeded", total_timeout_ms.unwrap()),
276+
code: crate::internal::llm_client::ErrorCode::Timeout,
277+
prompt: internal_baml_jinja::RenderedPrompt::Completion(String::new()),
278+
start_time: web_time::SystemTime::now(),
279+
latency: elapsed,
280+
request_options: Default::default(),
281+
}),
282+
Some(Err(anyhow::anyhow!(
283+
crate::errors::ExposedError::TimeoutError {
284+
client_name: client_name_for_timeout,
285+
message: format!(
286+
"Total timeout of {}ms exceeded (elapsed: {}ms)",
287+
total_timeout_ms.unwrap(),
288+
elapsed.as_millis()
289+
),
290+
}
291+
))),
292+
));
293+
break;
294+
}
204295
result = async {
205296
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
206297
Ok(p) => p,

0 commit comments

Comments
 (0)