Skip to content

Commit 6aa3794

Browse files
authored
feat(langchain): reference model profiles for provider strategy (#33974)
1 parent 189dcf7 commit 6aa3794

File tree

4 files changed

+75
-12
lines changed

4 files changed

+75
-12
lines changed

libs/langchain_v1/langchain/agents/factory.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,18 @@
6363

6464
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
6565

66+
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
67+
# If langchain-model-profiles is not installed, these models are assumed to
68+
# support structured output.
69+
"grok",
70+
"gpt-5",
71+
"gpt-4.1",
72+
"gpt-4o",
73+
"gpt-oss",
74+
"o3-pro",
75+
"o3-mini",
76+
]
77+
6678

6779
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
6880
"""Normalize middleware return value to ModelResponse."""
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
349361
return []
350362

351363

352-
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
364+
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
353365
"""Check if a model supports provider-specific structured output.
354366
355367
Args:
356368
model: Model name string or `BaseChatModel` instance.
369+
tools: Optional list of tools provided to the agent. Needed because some models
370+
don't support structured output together with tool calling.
357371
358372
Returns:
359373
`True` if the model supports provider-specific structured output, `False` otherwise.
@@ -362,11 +376,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
362376
if isinstance(model, str):
363377
model_name = model
364378
elif isinstance(model, BaseChatModel):
365-
model_name = getattr(model, "model_name", None)
379+
model_name = (
380+
getattr(model, "model_name", None)
381+
or getattr(model, "model", None)
382+
or getattr(model, "model_id", "")
383+
)
384+
try:
385+
model_profile = model.profile
386+
except ImportError:
387+
pass
388+
else:
389+
if (
390+
model_profile.get("structured_output")
391+
# We make an exception for Gemini models, which currently do not support
392+
# simultaneous tool use with structured output
393+
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
394+
):
395+
return True
366396

367397
return (
368-
"grok" in model_name.lower()
369-
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
398+
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
370399
if model_name
371400
else False
372401
)
@@ -988,7 +1017,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
9881017
effective_response_format: ResponseFormat | None
9891018
if isinstance(request.response_format, AutoStrategy):
9901019
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
991-
if _supports_provider_strategy(request.model):
1020+
if _supports_provider_strategy(request.model, tools=request.tools):
9921021
# Model supports provider strategy - use it
9931022
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
9941023
else:

libs/langchain_v1/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ test = [
5757
"pytest-mock",
5858
"syrupy>=4.0.2,<5.0.0",
5959
"toml>=0.10.2,<1.0.0",
60+
"langchain-model-profiles",
6061
"langchain-tests",
6162
"langchain-openai",
6263
]
@@ -75,6 +76,7 @@ test_integration = [
7576
"cassio>=0.1.0,<1.0.0",
7677
"langchainhub>=0.1.16,<1.0.0",
7778
"langchain-core",
79+
"langchain-model-profiles",
7880
"langchain-text-splitters",
7981
]
8082

@@ -83,6 +85,7 @@ prerelease = "allow"
8385

8486
[tool.uv.sources]
8587
langchain-core = { path = "../core", editable = true }
88+
langchain-model-profiles = { path = "../model-profiles", editable = true }
8689
langchain-tests = { path = "../standard-tests", editable = true }
8790
langchain-text-splitters = { path = "../text-splitters", editable = true }
8891
langchain-openai = { path = "../partners/openai", editable = true }

libs/langchain_v1/tests/unit_tests/agents/test_response_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def wrap_model_call(
790790
# Track which model is checked for provider strategy support
791791
calls = []
792792

793-
def mock_supports_provider_strategy(model) -> bool:
793+
def mock_supports_provider_strategy(model, tools) -> bool:
794794
"""Track which model is checked and return True for ProviderStrategy."""
795795
calls.append(model)
796796
return True

libs/langchain_v1/uv.lock

Lines changed: 37 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)