@@ -120,8 +120,15 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []


-def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
-    """Check if a model supports native structured output."""
+def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
+    """Check if a model supports provider-specific structured output.
+
+    Args:
+        model: Model name string or BaseChatModel instance.
+
+    Returns:
+        ``True`` if the model supports provider-specific structured output, ``False`` otherwise.
+    """
     model_name: str | None = None
     if isinstance(model, str):
         model_name = model
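Reviewer note: the renamed helper is what later drives strategy auto-detection. The sketch below is illustrative only; `pick_strategy` is a hypothetical wrapper, and it assumes `ToolStrategy`, `ProviderStrategy`, and `_supports_provider_strategy` are in scope as they are in this module.

```python
# Hypothetical illustration, not part of this PR. Assumes ToolStrategy,
# ProviderStrategy, and _supports_provider_strategy are in scope as in this module.
from langchain_core.language_models import BaseChatModel


def pick_strategy(model: str | BaseChatModel, raw_schema: type) -> ToolStrategy | ProviderStrategy:
    """Mirror the auto-detection rule that create_agent applies to raw schemas."""
    if _supports_provider_strategy(model):
        # Provider can enforce the schema natively, so no synthetic output tool is needed.
        return ProviderStrategy(schema=raw_schema)
    # Otherwise emulate structured output through a dedicated output tool.
    return ToolStrategy(schema=raw_schema)
```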
@@ -186,28 +193,25 @@ def create_agent( # noqa: PLR0915
     if tools is None:
         tools = []

-    # Setup structured output
-    structured_output_tools: dict[str, OutputToolBinding] = {}
-    native_output_binding: ProviderStrategyBinding | None = None
+    # Convert response format and setup structured output tools
+    # Raw schemas are converted to ToolStrategy upfront to calculate tools during agent creation.
+    # If auto-detection is needed, the strategy may be replaced with ProviderStrategy later.
+    initial_response_format: ToolStrategy | ProviderStrategy | None
+    is_auto_detect: bool
+    if response_format is None:
+        initial_response_format, is_auto_detect = None, False
+    elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
+        # Preserve explicitly requested strategies
+        initial_response_format, is_auto_detect = response_format, False
+    else:
+        # Raw schema - convert to ToolStrategy for now (may be replaced with ProviderStrategy)
+        initial_response_format, is_auto_detect = ToolStrategy(schema=response_format), True

-    if response_format is not None:
-        if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
-            # Auto-detect strategy based on model capabilities
-            if _supports_native_structured_output(model):
-                response_format = ProviderStrategy(schema=response_format)
-            else:
-                response_format = ToolStrategy(schema=response_format)
-
-        if isinstance(response_format, ToolStrategy):
-            # Setup tools strategy for structured output
-            for response_schema in response_format.schema_specs:
-                structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
-                structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
-        elif isinstance(response_format, ProviderStrategy):
-            # Setup native strategy
-            native_output_binding = ProviderStrategyBinding.from_schema_spec(
-                response_format.schema_spec
-            )
+    structured_output_tools: dict[str, OutputToolBinding] = {}
+    if isinstance(initial_response_format, ToolStrategy):
+        for response_schema in initial_response_format.schema_specs:
+            structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
+            structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

     # Setup tools
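To make the three branches above concrete, a usage sketch follows; `Answer` is a hypothetical Pydantic schema and the `"gpt-4o"` model string is only an example, neither comes from this PR:

```python
from pydantic import BaseModel


class Answer(BaseModel):
    """Hypothetical raw schema used for illustration."""

    text: str
    confidence: float


# response_format omitted: initial_response_format is None, no auto-detection.
agent = create_agent(model="gpt-4o", tools=[])

# Explicit strategy: preserved as-is, is_auto_detect stays False.
agent = create_agent(model="gpt-4o", tools=[], response_format=ToolStrategy(schema=Answer))

# Raw schema: wrapped in ToolStrategy up front so the output tool can be registered,
# with is_auto_detect=True so it may be swapped to ProviderStrategy at bind time.
agent = create_agent(model="gpt-4o", tools=[], response_format=Answer)
```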
@@ -280,18 +284,29 @@ def create_agent( # noqa: PLR0915
         context_schema=context_schema,
     )

-    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
-        """Handle model output including structured responses."""
-        # Handle structured output with native strategy
-        if isinstance(response_format, ProviderStrategy):
-            if not output.tool_calls and native_output_binding:
-                structured_response = native_output_binding.parse(output)
+    def _handle_model_output(
+        output: AIMessage, effective_response_format: ResponseFormat | None
+    ) -> dict[str, Any]:
+        """Handle model output including structured responses.
+
+        Args:
+            output: The AI message output from the model.
+            effective_response_format: The actual strategy used
+                (may differ from initial if auto-detected).
+        """
+        # Handle structured output with provider strategy
+        if isinstance(effective_response_format, ProviderStrategy):
+            if not output.tool_calls:
+                provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
+                    effective_response_format.schema_spec
+                )
+                structured_response = provider_strategy_binding.parse(output)
                 return {"messages": [output], "structured_response": structured_response}
             return {"messages": [output]}

-        # Handle structured output with tools strategy
+        # Handle structured output with tool strategy
         if (
-            isinstance(response_format, ToolStrategy)
+            isinstance(effective_response_format, ToolStrategy)
             and isinstance(output, AIMessage)
             and output.tool_calls
         ):
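For clarity, a rough sketch of the two update shapes the provider-strategy branch can produce, assuming the provider returned the structured payload directly in the message content (`Answer` is the hypothetical schema from the earlier sketch, and the exact payload accepted by `parse` is an assumption):

```python
from langchain_core.messages import AIMessage

# Assumed schema from the earlier sketch; ProviderStrategy and ProviderStrategyBinding
# are the names used in this diff.
strategy = ProviderStrategy(schema=Answer)
binding = ProviderStrategyBinding.from_schema_spec(strategy.schema_spec)

output = AIMessage(content='{"text": "42", "confidence": 0.9}')
if not output.tool_calls:
    # Mirrors the branch above: parse the raw message into the structured response.
    update = {"messages": [output], "structured_response": binding.parse(output)}
else:
    # Anything else falls through to a plain message update.
    update = {"messages": [output]}
```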
@@ -306,7 +321,7 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
                 tool_names = [tc["name"] for tc in structured_tool_calls]
                 exception = MultipleStructuredOutputsError(tool_names)
                 should_retry, error_message = _handle_structured_output_error(
-                    exception, response_format
+                    exception, effective_response_format
                 )
                 if not should_retry:
                     raise exception
@@ -329,8 +344,8 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
                 structured_response = structured_tool_binding.parse(tool_call["args"])

                 tool_message_content = (
-                    response_format.tool_message_content
-                    if response_format.tool_message_content
+                    effective_response_format.tool_message_content
+                    if effective_response_format.tool_message_content
                     else f"Returning structured response: {structured_response}"
                 )

@@ -348,7 +363,7 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
             except Exception as exc:  # noqa: BLE001
                 exception = StructuredOutputValidationError(tool_call["name"], exc)
                 should_retry, error_message = _handle_structured_output_error(
-                    exception, response_format
+                    exception, effective_response_format
                 )
                 if not should_retry:
                     raise exception
@@ -366,11 +381,20 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:

         return {"messages": [output]}

-    def _get_bound_model(request: ModelRequest) -> Runnable:
-        """Get the model with appropriate tool bindings."""
-        # Get actual tool objects from tool names
-        tools_by_name = {t.name: t for t in default_tools}
+    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
+        """Get the model with appropriate tool bindings.
+
+        Performs auto-detection of strategy if needed based on model capabilities.

+        Args:
+            request: The model request containing model, tools, and response format.
+
+        Returns:
+            Tuple of (bound_model, effective_response_format) where ``effective_response_format``
+            is the actual strategy used (may differ from initial if auto-detected).
+        """
+        # Validate requested tools are available
+        tools_by_name = {t.name: t for t in default_tools}
         unknown_tools = [name for name in request.tools if name not in tools_by_name]
         if unknown_tools:
             available_tools = sorted(tools_by_name.keys())
@@ -389,31 +413,57 @@ def _get_bound_model(request: ModelRequest) -> Runnable:

         requested_tools = [tools_by_name[name] for name in request.tools]

-        if isinstance(response_format, ProviderStrategy):
-            # Use native structured output
-            kwargs = response_format.to_model_kwargs()
-            return request.model.bind_tools(
-                requested_tools, strict=True, **kwargs, **request.model_settings
+        # Determine effective response format (auto-detect if needed)
+        effective_response_format: ResponseFormat | None = request.response_format
+        if (
+            # User provided raw schema - auto-detect best strategy based on model
+            is_auto_detect
+            and isinstance(request.response_format, ToolStrategy)
+            and
+            # Model supports provider strategy - use it instead
+            _supports_provider_strategy(request.model)
+        ):
+            effective_response_format = ProviderStrategy(schema=response_format)  # type: ignore[arg-type]
+        # else: keep ToolStrategy from initial conversion
+
+        # Bind model based on effective response format
+        if isinstance(effective_response_format, ProviderStrategy):
+            # Use provider-specific structured output
+            kwargs = effective_response_format.to_model_kwargs()
+            return (
+                request.model.bind_tools(
+                    requested_tools, strict=True, **kwargs, **request.model_settings
+                ),
+                effective_response_format,
             )
-        if isinstance(response_format, ToolStrategy):
+
+        if isinstance(effective_response_format, ToolStrategy):
+            # Force tool use if we have structured output tools
             tool_choice = "any" if structured_output_tools else request.tool_choice
-            return request.model.bind_tools(
-                requested_tools, tool_choice=tool_choice, **request.model_settings
+            return (
+                request.model.bind_tools(
+                    requested_tools, tool_choice=tool_choice, **request.model_settings
+                ),
+                effective_response_format,
             )
-        # Standard model binding
+
+        # No structured output - standard model binding
         if requested_tools:
-            return request.model.bind_tools(
-                requested_tools, tool_choice=request.tool_choice, **request.model_settings
+            return (
+                request.model.bind_tools(
+                    requested_tools, tool_choice=request.tool_choice, **request.model_settings
+                ),
+                None,
             )
-        return request.model.bind(**request.model_settings)
+        return request.model.bind(**request.model_settings), None

     def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
             tools=[t.name for t in default_tools],
             system_prompt=system_prompt,
-            response_format=response_format,
+            response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
         )
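From the caller's side, the new return contract of `_get_bound_model` can be summarized with a small sketch (`request` and `messages` construction is abbreviated). Resolving the strategy here rather than at agent-creation time matters because middleware can swap `request.model` in `modify_model_request`, so the capability check has to run per request:

```python
# Sketch of the new contract: the bound runnable plus the strategy actually applied
# for this request (None when no structured output is involved).
model_, effective_response_format = _get_bound_model(request)

output = model_.invoke(messages)
state_update = _handle_model_output(output, effective_response_format)
```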
@@ -431,8 +481,8 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
                 )
                 raise TypeError(msg)

-        # Get the final model and messages
-        model_ = _get_bound_model(request)
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
         messages = request.messages
         if request.system_prompt:
             messages = [SystemMessage(request.system_prompt), *messages]
441
491
return {
442
492
"thread_model_call_count" : state .get ("thread_model_call_count" , 0 ) + 1 ,
443
493
"run_model_call_count" : state .get ("run_model_call_count" , 0 ) + 1 ,
444
- ** _handle_model_output (output ),
494
+ ** _handle_model_output (output , effective_response_format ),
445
495
}
446
496
447
497
async def amodel_request (state : AgentState , runtime : Runtime [ContextT ]) -> dict [str , Any ]:
@@ -450,7 +500,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
             model=model,
             tools=[t.name for t in default_tools],
             system_prompt=system_prompt,
-            response_format=response_format,
+            response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
         )
@@ -459,8 +509,8 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
         for m in middleware_w_modify_model_request:
             await m.amodify_model_request(request, state, runtime)

-        # Get the final model and messages
-        model_ = _get_bound_model(request)
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
         messages = request.messages
         if request.system_prompt:
             messages = [SystemMessage(request.system_prompt), *messages]
@@ -469,7 +519,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
         return {
             "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
             "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **_handle_model_output(output),
+            **_handle_model_output(output, effective_response_format),
         }

     # Use sync or async based on model capabilities
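Taken together, the intended end-to-end behavior looks roughly like this; the invoke payload, the model identifier, and the `Answer` schema are assumptions carried over from the earlier sketches, not part of the diff:

```python
agent = create_agent(
    model="gpt-4o",          # if this model supports it, ProviderStrategy is auto-selected
    tools=[],
    response_format=Answer,  # raw schema: strategy is resolved per request at bind time
)

result = agent.invoke({"messages": [{"role": "user", "content": "Answer with confidence."}]})
structured = result["structured_response"]  # parsed instance of Answer
```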