48 changes: 11 additions & 37 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -15,7 +15,6 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from langchain_core._api.beta_decorator import beta
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
@@ -34,6 +33,7 @@
LangSmithParams,
LanguageModelInput,
)
from langchain_core.language_models.profile import ModelProfile
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -76,8 +76,6 @@
if TYPE_CHECKING:
import uuid

from langchain_model_profiles import ModelProfile # type: ignore[import-untyped]

from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
@@ -339,6 +337,16 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):

"""

profile: ModelProfile | None = Field(default=None, exclude=True)
"""Profile detailing model capabilities.

If not specified, the profile is loaded automatically from the provider package
at initialization, when profile data is available.

Example profile data includes context window sizes, supported modalities, and
support for tool calling, structured output, and other features.
"""

model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@@ -1688,40 +1696,6 @@ class AnswerWithJustification(BaseModel):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser

@property
@beta()
def profile(self) -> ModelProfile:
"""Return profiling information for the model.

This property relies on the `langchain-model-profiles` package to retrieve chat
model capabilities, such as context window sizes and supported features.

Raises:
ImportError: If `langchain-model-profiles` is not installed.

Returns:
A `ModelProfile` object containing profiling information for the model.
"""
try:
from langchain_model_profiles import get_model_profile # noqa: PLC0415
except ImportError as err:
informative_error_message = (
"To access model profiling information, please install the "
"`langchain-model-profiles` package: "
"`pip install langchain-model-profiles`."
)
raise ImportError(informative_error_message) from err

provider_id = self._llm_type
model_name = (
# Model name is not standardized across integrations. New integrations
# should prefer `model`.
getattr(self, "model", None)
or getattr(self, "model_name", None)
or getattr(self, "model_id", "")
)
return get_model_profile(provider_id, model_name) or {}


class SimpleChatModel(BaseChatModel):
"""Simplified implementation for a chat model to inherit from.
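The `profile` attribute added above replaces the removed `profile` property: rather than resolving capabilities through `langchain-model-profiles` at access time, the profile is now a plain field populated by the provider package. A minimal sketch of how a caller might read it, assuming a provider class such as `ChatAnthropic` and that the provider ships profile data (the class name, and any keys beyond `max_input_tokens`/`image_inputs` shown in this diff, are assumptions):

```python
# Hedged sketch: reading the new BaseChatModel.profile field.
# `ChatAnthropic` and automatic population of `profile` are assumptions;
# `profile` is None when no profile data is available.
from langchain_anthropic import ChatAnthropic

model = ChatAnthropic(model="claude-sonnet-4-5")

if model.profile is not None:
    # ModelProfile is a total=False TypedDict, i.e. a plain dict at runtime.
    context_window = model.profile.get("max_input_tokens")
    supports_images = model.profile.get("image_inputs", False)
    print(context_window, supports_images)
```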
11 changes: 11 additions & 0 deletions libs/core/langchain_core/language_models/profile/__init__.py
@@ -0,0 +1,11 @@
"""Model profile types and data loading utilities."""

from langchain_core.language_models.profile.model_profile import (
ModelProfile,
ModelProfileRegistry,
)

__all__ = [
"ModelProfile",
"ModelProfileRegistry",
]
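A small usage sketch of the new public surface; since `ModelProfile` is a `total=False` `TypedDict`, any subset of its keys is a valid value (the keys below are the ones visible elsewhere in this diff):

```python
from langchain_core.language_models.profile import ModelProfile, ModelProfileRegistry

profile: ModelProfile = {"max_input_tokens": 200_000, "image_inputs": True}
registry: ModelProfileRegistry = {"claude-sonnet-4-5": profile}
```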
128 changes: 128 additions & 0 deletions libs/core/langchain_core/language_models/profile/_data_loader.py
@@ -0,0 +1,128 @@
"""Data loader for model profiles with augmentation support."""

import json
import sys
from functools import cached_property
from pathlib import Path
from typing import Any

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib


class _DataLoader:
"""Loads and merges model profile data from base and augmentations.

See the README in the `data/augmentations` directory for more details on the
augmentation structure and merge priority.
"""

def __init__(self, data_dir: Path) -> None:
"""Initialize the loader.

Args:
data_dir: Path to the data directory containing models.json and
augmentations.
"""
self._data_dir = data_dir

@property
def _base_data_path(self) -> Path:
"""Get path to base data file.

`models.json` is the downloaded data from models.dev.
"""
return self._data_dir / "models.json"

@property
def _augmentations_file(self) -> Path:
"""Get path to profile augmentations file."""
return self._data_dir / "profile_augmentations.toml"

@cached_property
def _merged_data(self) -> dict[str, Any]:
"""Load and merge all data once at startup.

Merging order:

1. Base data from `models.json`
2. Provider-level augmentations from `[overrides]` in
`profile_augmentations.toml`
3. Model-level augmentations from `[overrides."model-name"]` in
`profile_augmentations.toml`

Returns:
Fully merged provider data with all augmentations applied.
"""
# Load base data; let exceptions propagate to user
with self._base_data_path.open("r") as f:
data = json.load(f)

# Load augmentations from profile_augmentations.toml
provider_aug, model_augs = self._load_augmentations()

# Merge augmentations into data
for provider_data in data.values():
models = provider_data.get("models", {})

for model_id, model_data in models.items():
# Apply provider-level augmentations
if provider_aug:
model_data.update(provider_aug)

# Apply model-level augmentations (highest priority)
if model_id in model_augs:
model_data.update(model_augs[model_id])

return data

def _load_augmentations(
self,
) -> tuple[dict[str, Any], dict[str, dict[str, Any]]]:
"""Load augmentations from profile_augmentations.toml.

Returns:
Tuple of (provider_augmentations, model_augmentations) where:
- provider_augmentations: dict of fields to apply to all models
- model_augmentations: dict mapping model IDs to their specific
augmentations
"""
if not self._augmentations_file.exists():
return {}, {}

with self._augmentations_file.open("rb") as f:
data = tomllib.load(f)

overrides = data.get("overrides", {})

# Separate provider-level augmentations from model-specific ones
# Model-specific overrides are nested dicts, while provider-level values are primitives
provider_aug: dict[str, Any] = {}
model_augs: dict[str, dict[str, Any]] = {}

for key, value in overrides.items():
if isinstance(value, dict):
# This is a model-specific override like [overrides."claude-sonnet-4-5"]
model_augs[key] = value
else:
# This is a provider-level field
provider_aug[key] = value

return provider_aug, model_augs

def get_profile_data(self, provider_id: str) -> dict[str, Any] | None:
"""Get merged profile data for all models.

Args:
provider_id: The provider identifier.

Returns:
Merged model data `dict` or `None` if not found.
"""
provider = self._merged_data.get(provider_id)
if provider is None:
return None

return provider.get("models", {})
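An illustrative sketch of the merge priority implemented by `_DataLoader._merged_data`, using in-memory dicts in place of `models.json` and `profile_augmentations.toml`; the field names are taken from elsewhere in this diff, and the merge is the same shallow `dict.update` applied per model:

```python
# Base entry for one model, as it would appear under a provider in models.json.
base_model = {
    "limit": {"context": 200_000},
    "modalities": {"input": ["text", "image"]},
}

provider_overrides = {"image_tool_message": True}  # from [overrides]
model_overrides = {  # from [overrides."claude-sonnet-4-5"]
    "claude-sonnet-4-5": {"pdf_tool_message": True},
}

merged = dict(base_model)
merged.update(provider_overrides)  # provider-level: applied to every model
merged.update(model_overrides.get("claude-sonnet-4-5", {}))  # model-level wins
# merged now holds the base fields plus both augmentation fields.
```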
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Utilities for loading model profiles from provider packages."""

from functools import lru_cache
from pathlib import Path

from langchain_core.language_models.profile._data_loader import _DataLoader
from langchain_core.language_models.profile.model_profile import (
ModelProfileRegistry,
map_raw_data_to_profile,
)


def load_profiles_from_data_dir(
data_dir: Path, provider_id: str
) -> ModelProfileRegistry | None:
"""Load model profiles from a provider's data directory.

Args:
data_dir: Path to the provider's data directory.
provider_id: The provider identifier (e.g., 'anthropic', 'openai').

Returns:
A `ModelProfileRegistry` mapping model names to their profiles, or `None` if not found.
"""
loader = _get_loader(data_dir)
data = loader.get_profile_data(provider_id)
if not data:
return None
return {
model_name: map_raw_data_to_profile(raw_profile)
for model_name, raw_profile in data.items()
}


@lru_cache(maxsize=32)
def _get_loader(data_dir: Path) -> _DataLoader:
"""Get a cached loader for a data directory.

Args:
data_dir: Path to the data directory.

Returns:
A cached `_DataLoader` instance.
"""
return _DataLoader(data_dir)
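A hedged sketch of how a provider package might use this loader at import time. The module path in the import is hypothetical (the new file's name is not visible in this rendering), and the data directory layout follows `_DataLoader` above:

```python
from pathlib import Path

# Hypothetical import path; the actual module name is not shown in this diff.
from langchain_core.language_models.profile.data import load_profiles_from_data_dir

# Assumed layout: a `data/` directory shipped with the provider package,
# containing models.json and (optionally) profile_augmentations.toml.
_PROFILES = load_profiles_from_data_dir(Path(__file__).parent / "data", "anthropic")


def get_profile(model_name: str):
    """Return the ModelProfile for a model name, or None if unknown."""
    if _PROFILES is None:
        return None
    return _PROFILES.get(model_name)
```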
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Model profiles package."""
"""Model profile types and utilities."""

import re
from typing import Any

from typing_extensions import TypedDict

from langchain_model_profiles._data_loader import _DataLoader


class ModelProfile(TypedDict, total=False):
"""Model profile."""
"""Model profile.

Provides information about chat model capabilities, such as context window sizes
and supported features.
"""

# --- Input constraints ---

Expand Down Expand Up @@ -39,10 +41,10 @@ class ModelProfile(TypedDict, total=False):
# TODO: add more detail about formats? e.g. bytes or base64

image_tool_message: bool
"""TODO: description."""
"""Whether images can be included in tool messages."""

pdf_tool_message: bool
"""TODO: description."""
"""Whether PDFs can be included in tool messages."""

# --- Output constraints ---

@@ -77,77 +79,24 @@ class ModelProfile(TypedDict, total=False):
feature"""


_LOADER = _DataLoader()

_lc_type_to_provider_id = {
"openai-chat": "openai",
"azure-openai-chat": "azure",
"anthropic-chat": "anthropic",
"chat-google-generative-ai": "google",
"vertexai": "google-vertex",
"anthropic-chat-vertexai": "google-vertex-anthropic",
"amazon_bedrock_chat": "amazon-bedrock",
"amazon_bedrock_converse_chat": "amazon-bedrock",
"chat-ai21": "ai21",
"chat-deepseek": "deepseek",
"fireworks-chat": "fireworks-ai",
"groq-chat": "groq",
"huggingface-chat-wrapper": "huggingface",
"mistralai-chat": "mistral",
"chat-ollama": "ollama",
"perplexitychat": "perplexity",
"together-chat": "togetherai",
"upstage-chat": "upstage",
"xai-chat": "xai",
}


def _translate_provider_and_model_id(provider: str, model: str) -> tuple[str, str]:
"""Translate LangChain provider and model to models.dev equivalents.
ModelProfileRegistry = dict[str, ModelProfile]
"""Registry mapping model identifiers or names to their ModelProfile."""

Args:
provider: LangChain provider ID.
model: LangChain model ID.

Returns:
A tuple containing the models.dev provider ID and model ID.
"""
provider_id = _lc_type_to_provider_id.get(provider, provider)
def map_raw_data_to_profile(data: dict[str, Any]) -> ModelProfile:
"""Map raw model data to ModelProfile format.

if provider_id in ("google", "google-vertex"):
# convert models/gemini-2.0-flash-001 to gemini-2.0-flash
model_id = re.sub(r"-\d{3}$", "", model.replace("models/", ""))
elif provider_id == "amazon-bedrock":
# strip region prefixes like "us."
model_id = re.sub(r"^[A-Za-z]{2}\.", "", model)
else:
model_id = model

return provider_id, model_id


def get_model_profile(provider: str, model: str) -> ModelProfile | None:
"""Get the model capabilities for a given model.
This function is used by provider packages to convert raw data from models.dev
and augmentations into the standardized ModelProfile format.

Args:
provider: Identifier for provider (e.g., `'openai'`, `'anthropic'`).
model: Identifier for model (e.g., `'gpt-5'`,
`'claude-sonnet-4-5-20250929'`).
data: Raw model data from models.dev and augmentations.

Returns:
The model capabilities or `None` if not found in the data.
ModelProfile with standardized fields.
"""
if not provider or not model:
return None

provider_id, model_id = _translate_provider_and_model_id(provider, model)
data = _LOADER.get_profile_data(provider_id, model_id)
if not data:
# If either (1) provider not found or (2) model not found under matched provider
return None

# Map models.dev & augmentation fields -> ModelProfile fields
# See schema reference to see fields dropped: https://github.com/sst/models.dev?tab=readme-ov-file#schema-reference
# See schema reference: https://github.com/sst/models.dev?tab=readme-ov-file#schema-reference
profile = {
"max_input_tokens": data.get("limit", {}).get("context"),
"image_inputs": "image" in data.get("modalities", {}).get("input", []),
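The tail of `map_raw_data_to_profile` is collapsed in this view; based on the visible lines, here is a minimal sketch of the mapping for the two fields shown (the real function maps additional models.dev fields per the schema reference linked above):

```python
raw = {
    "limit": {"context": 200_000},
    "modalities": {"input": ["text", "image"]},
}

profile = {
    "max_input_tokens": raw.get("limit", {}).get("context"),
    "image_inputs": "image" in raw.get("modalities", {}).get("input", []),
}
# profile == {"max_input_tokens": 200000, "image_inputs": True}
```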