Skip to content

Commit 905c6d7

Browse files
authored
fix(langchain_v1): handle switching response format strategy based on model identity (#33259)
Change the response format strategy dynamically based on the model. After this PR, two issues remain:
- [ ] Review the binding of tools used for output to ToolNode (shouldn't be required)
- [ ] Update ModelRequest to also support the original schema provided by the user (to correctly support auto mode)
1 parent acd1aa8 commit 905c6d7

File tree

2 files changed

+201
-60
lines changed

2 files changed

+201
-60
lines changed

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 109 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,15 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
120120
return []
121121

122122

123-
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
124-
"""Check if a model supports native structured output."""
123+
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
124+
"""Check if a model supports provider-specific structured output.
125+
126+
Args:
127+
model: Model name string or BaseChatModel instance.
128+
129+
Returns:
130+
``True`` if the model supports provider-specific structured output, ``False`` otherwise.
131+
"""
125132
model_name: str | None = None
126133
if isinstance(model, str):
127134
model_name = model
@@ -186,28 +193,25 @@ def create_agent( # noqa: PLR0915
186193
if tools is None:
187194
tools = []
188195

189-
# Setup structured output
190-
structured_output_tools: dict[str, OutputToolBinding] = {}
191-
native_output_binding: ProviderStrategyBinding | None = None
196+
# Convert response format and setup structured output tools
197+
# Raw schemas are converted to ToolStrategy upfront to calculate tools during agent creation.
198+
# If auto-detection is needed, the strategy may be replaced with ProviderStrategy later.
199+
initial_response_format: ToolStrategy | ProviderStrategy | None
200+
is_auto_detect: bool
201+
if response_format is None:
202+
initial_response_format, is_auto_detect = None, False
203+
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
204+
# Preserve explicitly requested strategies
205+
initial_response_format, is_auto_detect = response_format, False
206+
else:
207+
# Raw schema - convert to ToolStrategy for now (may be replaced with ProviderStrategy)
208+
initial_response_format, is_auto_detect = ToolStrategy(schema=response_format), True
192209

193-
if response_format is not None:
194-
if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
195-
# Auto-detect strategy based on model capabilities
196-
if _supports_native_structured_output(model):
197-
response_format = ProviderStrategy(schema=response_format)
198-
else:
199-
response_format = ToolStrategy(schema=response_format)
200-
201-
if isinstance(response_format, ToolStrategy):
202-
# Setup tools strategy for structured output
203-
for response_schema in response_format.schema_specs:
204-
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
205-
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
206-
elif isinstance(response_format, ProviderStrategy):
207-
# Setup native strategy
208-
native_output_binding = ProviderStrategyBinding.from_schema_spec(
209-
response_format.schema_spec
210-
)
210+
structured_output_tools: dict[str, OutputToolBinding] = {}
211+
if isinstance(initial_response_format, ToolStrategy):
212+
for response_schema in initial_response_format.schema_specs:
213+
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
214+
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
211215
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
212216

213217
# Setup tools
@@ -280,18 +284,29 @@ def create_agent( # noqa: PLR0915
280284
context_schema=context_schema,
281285
)
282286

283-
def _handle_model_output(output: AIMessage) -> dict[str, Any]:
284-
"""Handle model output including structured responses."""
285-
# Handle structured output with native strategy
286-
if isinstance(response_format, ProviderStrategy):
287-
if not output.tool_calls and native_output_binding:
288-
structured_response = native_output_binding.parse(output)
287+
def _handle_model_output(
288+
output: AIMessage, effective_response_format: ResponseFormat | None
289+
) -> dict[str, Any]:
290+
"""Handle model output including structured responses.
291+
292+
Args:
293+
output: The AI message output from the model.
294+
effective_response_format: The actual strategy used
295+
(may differ from initial if auto-detected).
296+
"""
297+
# Handle structured output with provider strategy
298+
if isinstance(effective_response_format, ProviderStrategy):
299+
if not output.tool_calls:
300+
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
301+
effective_response_format.schema_spec
302+
)
303+
structured_response = provider_strategy_binding.parse(output)
289304
return {"messages": [output], "structured_response": structured_response}
290305
return {"messages": [output]}
291306

292-
# Handle structured output with tools strategy
307+
# Handle structured output with tool strategy
293308
if (
294-
isinstance(response_format, ToolStrategy)
309+
isinstance(effective_response_format, ToolStrategy)
295310
and isinstance(output, AIMessage)
296311
and output.tool_calls
297312
):
@@ -306,7 +321,7 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
306321
tool_names = [tc["name"] for tc in structured_tool_calls]
307322
exception = MultipleStructuredOutputsError(tool_names)
308323
should_retry, error_message = _handle_structured_output_error(
309-
exception, response_format
324+
exception, effective_response_format
310325
)
311326
if not should_retry:
312327
raise exception
@@ -329,8 +344,8 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
329344
structured_response = structured_tool_binding.parse(tool_call["args"])
330345

331346
tool_message_content = (
332-
response_format.tool_message_content
333-
if response_format.tool_message_content
347+
effective_response_format.tool_message_content
348+
if effective_response_format.tool_message_content
334349
else f"Returning structured response: {structured_response}"
335350
)
336351

@@ -348,7 +363,7 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
348363
except Exception as exc: # noqa: BLE001
349364
exception = StructuredOutputValidationError(tool_call["name"], exc)
350365
should_retry, error_message = _handle_structured_output_error(
351-
exception, response_format
366+
exception, effective_response_format
352367
)
353368
if not should_retry:
354369
raise exception
@@ -366,11 +381,20 @@ def _handle_model_output(output: AIMessage) -> dict[str, Any]:
366381

367382
return {"messages": [output]}
368383

369-
def _get_bound_model(request: ModelRequest) -> Runnable:
370-
"""Get the model with appropriate tool bindings."""
371-
# Get actual tool objects from tool names
372-
tools_by_name = {t.name: t for t in default_tools}
384+
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
385+
"""Get the model with appropriate tool bindings.
386+
387+
Performs auto-detection of strategy if needed based on model capabilities.
373388
389+
Args:
390+
request: The model request containing model, tools, and response format.
391+
392+
Returns:
393+
Tuple of (bound_model, effective_response_format) where ``effective_response_format``
394+
is the actual strategy used (may differ from initial if auto-detected).
395+
"""
396+
# Validate requested tools are available
397+
tools_by_name = {t.name: t for t in default_tools}
374398
unknown_tools = [name for name in request.tools if name not in tools_by_name]
375399
if unknown_tools:
376400
available_tools = sorted(tools_by_name.keys())
@@ -389,31 +413,57 @@ def _get_bound_model(request: ModelRequest) -> Runnable:
389413

390414
requested_tools = [tools_by_name[name] for name in request.tools]
391415

392-
if isinstance(response_format, ProviderStrategy):
393-
# Use native structured output
394-
kwargs = response_format.to_model_kwargs()
395-
return request.model.bind_tools(
396-
requested_tools, strict=True, **kwargs, **request.model_settings
416+
# Determine effective response format (auto-detect if needed)
417+
effective_response_format: ResponseFormat | None = request.response_format
418+
if (
419+
# User provided raw schema - auto-detect best strategy based on model
420+
is_auto_detect
421+
and isinstance(request.response_format, ToolStrategy)
422+
and
423+
# Model supports provider strategy - use it instead
424+
_supports_provider_strategy(request.model)
425+
):
426+
effective_response_format = ProviderStrategy(schema=response_format) # type: ignore[arg-type]
427+
# else: keep ToolStrategy from initial conversion
428+
429+
# Bind model based on effective response format
430+
if isinstance(effective_response_format, ProviderStrategy):
431+
# Use provider-specific structured output
432+
kwargs = effective_response_format.to_model_kwargs()
433+
return (
434+
request.model.bind_tools(
435+
requested_tools, strict=True, **kwargs, **request.model_settings
436+
),
437+
effective_response_format,
397438
)
398-
if isinstance(response_format, ToolStrategy):
439+
440+
if isinstance(effective_response_format, ToolStrategy):
441+
# Force tool use if we have structured output tools
399442
tool_choice = "any" if structured_output_tools else request.tool_choice
400-
return request.model.bind_tools(
401-
requested_tools, tool_choice=tool_choice, **request.model_settings
443+
return (
444+
request.model.bind_tools(
445+
requested_tools, tool_choice=tool_choice, **request.model_settings
446+
),
447+
effective_response_format,
402448
)
403-
# Standard model binding
449+
450+
# No structured output - standard model binding
404451
if requested_tools:
405-
return request.model.bind_tools(
406-
requested_tools, tool_choice=request.tool_choice, **request.model_settings
452+
return (
453+
request.model.bind_tools(
454+
requested_tools, tool_choice=request.tool_choice, **request.model_settings
455+
),
456+
None,
407457
)
408-
return request.model.bind(**request.model_settings)
458+
return request.model.bind(**request.model_settings), None
409459

410460
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
411461
"""Sync model request handler with sequential middleware processing."""
412462
request = ModelRequest(
413463
model=model,
414464
tools=[t.name for t in default_tools],
415465
system_prompt=system_prompt,
416-
response_format=response_format,
466+
response_format=initial_response_format,
417467
messages=state["messages"],
418468
tool_choice=None,
419469
)
@@ -431,8 +481,8 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
431481
)
432482
raise TypeError(msg)
433483

434-
# Get the final model and messages
435-
model_ = _get_bound_model(request)
484+
# Get the bound model (with auto-detection if needed)
485+
model_, effective_response_format = _get_bound_model(request)
436486
messages = request.messages
437487
if request.system_prompt:
438488
messages = [SystemMessage(request.system_prompt), *messages]
@@ -441,7 +491,7 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
441491
return {
442492
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
443493
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
444-
**_handle_model_output(output),
494+
**_handle_model_output(output, effective_response_format),
445495
}
446496

447497
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
@@ -450,7 +500,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
450500
model=model,
451501
tools=[t.name for t in default_tools],
452502
system_prompt=system_prompt,
453-
response_format=response_format,
503+
response_format=initial_response_format,
454504
messages=state["messages"],
455505
tool_choice=None,
456506
)
@@ -459,8 +509,8 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
459509
for m in middleware_w_modify_model_request:
460510
await m.amodify_model_request(request, state, runtime)
461511

462-
# Get the final model and messages
463-
model_ = _get_bound_model(request)
512+
# Get the bound model (with auto-detection if needed)
513+
model_, effective_response_format = _get_bound_model(request)
464514
messages = request.messages
465515
if request.system_prompt:
466516
messages = [SystemMessage(request.system_prompt), *messages]
@@ -469,7 +519,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
469519
return {
470520
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
471521
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
472-
**_handle_model_output(output),
522+
**_handle_model_output(output, effective_response_format),
473523
}
474524

475525
# Use sync or async based on model capabilities

libs/langchain_v1/tests/unit_tests/agents/test_response_format.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Test suite for create_agent with structured output response_format permutations."""
22

3+
import json
4+
35
import pytest
46

57
from dataclasses import dataclass
6-
from typing import Union
8+
from typing import Union, Sequence, Any, Callable
79

810
from langchain_core.messages import HumanMessage
911
from langchain.agents import create_agent
@@ -13,10 +15,16 @@
1315
StructuredOutputValidationError,
1416
ToolStrategy,
1517
)
18+
from langchain.tools import tool
1619
from pydantic import BaseModel, Field
1720
from typing_extensions import TypedDict
1821

22+
from langchain.messages import AIMessage
23+
from langchain_core.messages import BaseMessage
24+
from langchain_core.language_models import LanguageModelInput
25+
from langchain_core.runnables import Runnable
1926
from tests.unit_tests.agents.model import FakeToolCallingModel
27+
from langchain.tools import BaseTool
2028

2129

2230
# Test data models
@@ -676,6 +684,89 @@ def test_json_schema(self) -> None:
676684
assert len(response["messages"]) == 4
677685

678686

687+
class TestDynamicModelWithResponseFormat:
688+
"""Test response_format with middleware that modifies the model."""
689+
690+
def test_middleware_model_swap_provider_to_tool_strategy(self) -> None:
691+
"""Test that strategy resolution is deferred until after middleware modifies the model.
692+
693+
Verifies that when a raw schema is provided, ``_supports_provider_strategy`` is called
694+
on the middleware-modified model (not the original), ensuring the correct strategy is
695+
selected based on the final model's capabilities.
696+
"""
697+
from unittest.mock import patch
698+
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
699+
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
700+
701+
# Custom model that we'll use to test whether the tool strategy is applied
702+
# correctly at runtime.
703+
class CustomModel(GenericFakeChatModel):
704+
tool_bindings: list[Any] = []
705+
706+
def bind_tools(
707+
self,
708+
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
709+
**kwargs: Any,
710+
) -> Runnable[LanguageModelInput, BaseMessage]:
711+
# Record every tool binding event.
712+
self.tool_bindings.append(tools)
713+
return self
714+
715+
model = CustomModel(
716+
messages=iter(
717+
[
718+
# Simulate model returning structured output directly
719+
# (this is what provider strategy would do)
720+
json.dumps(WEATHER_DATA),
721+
]
722+
)
723+
)
724+
725+
# Create middleware that swaps the model in the request
726+
class ModelSwappingMiddleware(AgentMiddleware):
727+
def modify_model_request(self, request: ModelRequest, state, runtime) -> ModelRequest:
728+
# Replace the model with our custom test model
729+
request.model = model
730+
return request
731+
732+
# Track which model is checked for provider strategy support
733+
calls = []
734+
735+
def mock_supports_provider_strategy(model) -> bool:
736+
"""Track which model is checked and return True for ProviderStrategy."""
737+
calls.append(model)
738+
return True
739+
740+
# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
741+
# This should auto-detect strategy based on model capabilities
742+
agent = create_agent(
743+
model=model,
744+
tools=[],
745+
# Raw schema - should auto-detect strategy
746+
response_format=WeatherBaseModel,
747+
middleware=[ModelSwappingMiddleware()],
748+
)
749+
750+
with patch(
751+
"langchain.agents.middleware_agent._supports_provider_strategy",
752+
side_effect=mock_supports_provider_strategy,
753+
):
754+
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
755+
756+
# Verify strategy resolution was deferred: check was called once during _get_bound_model
757+
assert len(calls) == 1
758+
759+
# Verify successful parsing of JSON as structured output via ProviderStrategy
760+
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
761+
# Two messages: Human input message and AI response with JSON content
762+
assert len(response["messages"]) == 2
763+
ai_message = response["messages"][1]
764+
assert isinstance(ai_message, AIMessage)
765+
# ProviderStrategy doesn't use tool calls - it parses content directly
766+
assert ai_message.tool_calls == []
767+
assert ai_message.content == json.dumps(WEATHER_DATA)
768+
769+
679770
def test_union_of_types() -> None:
680771
"""Test response_format as ProviderStrategy with Union (if supported)."""
681772
tool_calls = [

0 commit comments

Comments
 (0)