48 changes: 11 additions & 37 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -15,7 +15,6 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from langchain_core._api.beta_decorator import beta
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
@@ -34,6 +33,7 @@
LangSmithParams,
LanguageModelInput,
)
from langchain_core.language_models.profile import ModelProfile
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -76,8 +76,6 @@
if TYPE_CHECKING:
import uuid

from langchain_model_profiles import ModelProfile # type: ignore[import-untyped]

from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
@@ -339,6 +337,16 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):

"""

profile: ModelProfile | None = Field(default=None, exclude=True)
"""Return profiling information for the model.

Profile data includes model capabilities such as context window sizes and
supported features. Data is automatically loaded from the provider package
if available.

Assign this attribute to override what the `profile` property returns.
"""

model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@@ -1688,40 +1696,6 @@ class AnswerWithJustification(BaseModel):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser

@property
@beta()
def profile(self) -> ModelProfile:
"""Return profiling information for the model.

This property relies on the `langchain-model-profiles` package to retrieve chat
model capabilities, such as context window sizes and supported features.

Raises:
ImportError: If `langchain-model-profiles` is not installed.

Returns:
A `ModelProfile` object containing profiling information for the model.
"""
try:
from langchain_model_profiles import get_model_profile # noqa: PLC0415
except ImportError as err:
informative_error_message = (
"To access model profiling information, please install the "
"`langchain-model-profiles` package: "
"`pip install langchain-model-profiles`."
)
raise ImportError(informative_error_message) from err

provider_id = self._llm_type
model_name = (
# Model name is not standardized across integrations. New integrations
# should prefer `model`.
getattr(self, "model", None)
or getattr(self, "model_name", None)
or getattr(self, "model_id", "")
)
return get_model_profile(provider_id, model_name) or {}


class SimpleChatModel(BaseChatModel):
"""Simplified implementation for a chat model to inherit from.
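With the `@beta` `profile` property removed, `profile` is now an ordinary Pydantic field on `BaseChatModel`: it defaults to `None` and can be set at construction time. A minimal sketch of the new behavior, mirroring the updated unit test at the bottom of this diff (it assumes `GenericFakeChatModel` is importable from `langchain_core.language_models`):

```python
from langchain_core.language_models import GenericFakeChatModel

# Without provider-shipped profile data, the field defaults to None.
model = GenericFakeChatModel(messages=iter([]))
assert model.profile is None

# Assigning at init overrides whatever the provider package would load.
model = GenericFakeChatModel(
    messages=iter([]), profile={"max_input_tokens": 100}
)
assert model.profile == {"max_input_tokens": 100}
```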
11 changes: 11 additions & 0 deletions libs/core/langchain_core/language_models/profile/__init__.py
@@ -0,0 +1,11 @@
"""Model profile types and data loading utilities."""

from langchain_core.language_models.profile.model_profile import (
ModelProfile,
ModelProfileRegistry,
)

__all__ = [
"ModelProfile",
"ModelProfileRegistry",
]
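The new package re-exports its two public names, giving downstream code a stable import path. A quick illustrative sketch (the model name and token count are made up):

```python
from langchain_core.language_models.profile import (
    ModelProfile,
    ModelProfileRegistry,
)

# A profile is a TypedDict of capabilities; a registry maps model names
# to their profiles.
profile: ModelProfile = {"max_input_tokens": 200_000, "image_inputs": True}
registry: ModelProfileRegistry = {"gpt-5": profile}
```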
libs/core/langchain_core/language_models/profile/_data_loader.py
@@ -19,9 +19,14 @@ class _DataLoader:
augmentation structure and merge priority.
"""

def __init__(self) -> None:
"""Initialize the loader."""
self._data_dir = Path(__file__).parent / "data"
def __init__(self, data_dir: Path) -> None:
"""Initialize the loader.

Args:
data_dir: Path to the data directory containing models.json and
augmentations.
"""
self._data_dir = data_dir

@property
def _base_data_path(self) -> Path:
@@ -73,23 +78,28 @@ def _merged_data(self) -> dict[str, Any]:
return data

def _load_provider_augmentations(self) -> dict[str, dict[str, Any]]:
"""Load all provider-level augmentations.
"""Load provider-level augmentations from profiles.toml.

Returns:
`dict` mapping provider IDs to their augmentation data.
"""
augmentations: dict[str, dict[str, Any]] = {}
providers_dir = self._augmentations_dir / "providers"
profiles_file = self._augmentations_dir / "profiles.toml"

if not providers_dir.exists():
if not profiles_file.exists():
return augmentations

for toml_file in providers_dir.glob("*.toml"):
provider_id = toml_file.stem
with toml_file.open("rb") as f:
data = tomllib.load(f)
if "profile" in data:
augmentations[provider_id] = data["profile"]
with profiles_file.open("rb") as f:
data = tomllib.load(f)
if "profile" in data:
# Load all provider IDs from base data and apply augmentation to all
try:
with self._base_data_path.open("r") as base_f:
base_data = json.load(base_f)
for provider_id in base_data:
augmentations[provider_id] = data["profile"]
except (OSError, json.JSONDecodeError):
pass

return augmentations

@@ -121,14 +131,11 @@ def _load_model_augmentations(self) -> dict[str, dict[str, dict[str, Any]]]:

return augmentations

def get_profile_data(
self, provider_id: str, model_id: str
) -> dict[str, Any] | None:
"""Get merged profile data for a specific model.
def get_profile_data(self, provider_id: str) -> dict[str, Any] | None:
"""Get merged profile data for all models.

Args:
provider_id: The provider identifier.
model_id: The model identifier.

Returns:
Merged model data `dict` or `None` if not found.
@@ -137,5 +144,4 @@ def get_profile_data(
if provider is None:
return None

models = provider.get("models", {})
return models.get(model_id)
return provider.get("models", {})
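Taken together, these changes mean a `_DataLoader` is constructed per data directory and now returns a provider's entire model map in one call. A rough sketch of driving it directly (this is private API; the on-disk layout of `models.json` plus augmentation files is taken from the docstrings above, and the directory path is hypothetical):

```python
from pathlib import Path

from langchain_core.language_models.profile._data_loader import _DataLoader

# Hypothetical data directory bundled with a provider package; per the
# __init__ docstring it contains models.json and the augmentations.
data_dir = Path(__file__).parent / "data"

loader = _DataLoader(data_dir)

# Merged {model_id: raw_model_data} for one provider, or None if the
# provider does not appear in models.json.
raw_models = loader.get_profile_data("anthropic")
```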
@@ -0,0 +1,45 @@
"""Utilities for loading model profiles from provider packages."""

from functools import lru_cache
from pathlib import Path

from langchain_core.language_models.profile._data_loader import _DataLoader
from langchain_core.language_models.profile.model_profile import (
ModelProfileRegistry,
map_raw_data_to_profile,
)


def load_profiles_from_data_dir(
data_dir: Path, provider_id: str
) -> ModelProfileRegistry | None:
"""Load model profiles from a provider's data directory.

Args:
data_dir: Path to the provider's data directory.
provider_id: The provider identifier (e.g., 'anthropic', 'openai').

Returns:
ModelProfileRegistry mapping model names to their profiles, or None if not found.
"""
loader = _get_loader(data_dir)
data = loader.get_profile_data(provider_id)
if not data:
return None
return {
model_name: map_raw_data_to_profile(raw_profile)
for model_name, raw_profile in data.items()
}


@lru_cache(maxsize=32)
def _get_loader(data_dir: Path) -> _DataLoader:
"""Get a cached loader for a data directory.

Args:
data_dir: Path to the data directory.

Returns:
_DataLoader instance.
"""
return _DataLoader(data_dir)
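A provider package would call `load_profiles_from_data_dir` once and hand the resulting registry to its chat models; the `lru_cache` on `_get_loader` keeps repeated calls cheap. A sketch of that wiring (the importing module's name is not visible in this diff, so the import path below is a placeholder, and the model name is only an example):

```python
from pathlib import Path

# NOTE: placeholder module path -- the diff does not name the file that
# defines load_profiles_from_data_dir within the profile package.
from langchain_core.language_models.profile.data_loading import (
    load_profiles_from_data_dir,
)

_DATA_DIR = Path(__file__).parent / "data"  # data shipped with the provider

# A ModelProfileRegistry mapping model names to profiles, or None if the
# provider has no entry in the bundled data.
profiles = load_profiles_from_data_dir(_DATA_DIR, "anthropic")
if profiles is not None:
    profile = profiles.get("claude-sonnet-4-5-20250929")
```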
libs/core/langchain_core/language_models/profile/model_profile.py
@@ -1,14 +1,17 @@
"""Model profiles package."""
"""Model profile types and utilities."""

import re
from typing import Any

from typing_extensions import TypedDict

from langchain_model_profiles._data_loader import _DataLoader


class ModelProfile(TypedDict, total=False):
"""Model profile."""
"""Model profile.

Provides information about chat model capabilities, such as context window sizes
and supported features.
"""

# --- Input constraints ---

@@ -39,10 +42,10 @@ class ModelProfile(TypedDict, total=False):
# TODO: add more detail about formats? e.g. bytes or base64

image_tool_message: bool
"""TODO: description."""
"""Whether images can be included in tool messages."""

pdf_tool_message: bool
"""TODO: description."""
"""Whether PDFs can be included in tool messages."""

# --- Output constraints ---

@@ -77,7 +80,9 @@ class ModelProfile(TypedDict, total=False):
feature"""


_LOADER = _DataLoader()
ModelProfileRegistry = dict[str, ModelProfile]
"""Registry mapping model identifiers or names to their ModelProfile."""


_lc_type_to_provider_id = {
"openai-chat": "openai",
@@ -102,6 +107,7 @@ class ModelProfile(TypedDict, total=False):
}


# TODO: delete this function
def _translate_provider_and_model_id(provider: str, model: str) -> tuple[str, str]:
"""Translate LangChain provider and model to models.dev equivalents.

@@ -126,28 +132,20 @@ def _translate_provider_and_model_id(provider: str, model: str) -> tuple[str, str]:
return provider_id, model_id


def get_model_profile(provider: str, model: str) -> ModelProfile | None:
"""Get the model capabilities for a given model.
def map_raw_data_to_profile(data: dict[str, Any]) -> ModelProfile:
"""Map raw model data to ModelProfile format.

This function is used by provider packages to convert raw data from models.dev
and augmentations into the standardized ModelProfile format.

Args:
provider: Identifier for provider (e.g., `'openai'`, `'anthropic'`).
model: Identifier for model (e.g., `'gpt-5'`,
`'claude-sonnet-4-5-20250929'`).
data: Raw model data from models.dev and augmentations.

Returns:
The model capabilities or `None` if not found in the data.
ModelProfile with standardized fields.
"""
if not provider or not model:
return None

provider_id, model_id = _translate_provider_and_model_id(provider, model)
data = _LOADER.get_profile_data(provider_id, model_id)
if not data:
# If either (1) provider not found or (2) model not found under matched provider
return None

# Map models.dev & augmentation fields -> ModelProfile fields
# See schema reference to see fields dropped: https://github.com/sst/models.dev?tab=readme-ov-file#schema-reference
# See schema reference: https://github.com/sst/models.dev?tab=readme-ov-file#schema-reference
profile = {
"max_input_tokens": data.get("limit", {}).get("context"),
"image_inputs": "image" in data.get("modalities", {}).get("input", []),
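`map_raw_data_to_profile` is now the public conversion point from models.dev-shaped raw data to a `ModelProfile`. A minimal sketch exercising only the two field mappings visible in this hunk (`limit.context` to `max_input_tokens`, and input modalities to `image_inputs`); real entries carry more fields:

```python
from langchain_core.language_models.profile.model_profile import (
    map_raw_data_to_profile,
)

# Raw entry shaped like the models.dev schema, truncated to the fields
# whose mapping is shown in this diff.
raw = {
    "limit": {"context": 200_000},
    "modalities": {"input": ["text", "image"]},
}

profile = map_raw_data_to_profile(raw)
assert profile["max_input_tokens"] == 200_000
assert profile["image_inputs"] is True
```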
3 changes: 0 additions & 3 deletions libs/core/pyproject.toml
@@ -36,7 +36,6 @@ typing = [
"mypy>=1.18.1,<1.19.0",
"types-pyyaml>=6.0.12.2,<7.0.0.0",
"types-requests>=2.28.11.5,<3.0.0.0",
"langchain-model-profiles",
"langchain-text-splitters",
]
dev = [
@@ -58,15 +57,13 @@ test = [
"blockbuster>=1.5.18,<1.6.0",
"numpy>=1.26.4; python_version<'3.13'",
"numpy>=2.1.0; python_version>='3.13'",
"langchain-model-profiles",
"langchain-tests",
"pytest-benchmark",
"pytest-codspeed",
]
test_integration = []

[tool.uv.sources]
langchain-model-profiles = { path = "../model-profiles" }
langchain-tests = { path = "../standard-tests" }
langchain-text-splitters = { path = "../text-splitters" }

@@ -1222,19 +1222,12 @@ def _llm_type(self) -> str:

def test_model_profiles() -> None:
model = GenericFakeChatModel(messages=iter([]))
profile = model.profile
assert profile == {}
assert model.profile is None

class MyModel(GenericFakeChatModel):
model: str = "gpt-5"

@property
def _llm_type(self) -> str:
return "openai-chat"

model = MyModel(messages=iter([]))
profile = model.profile
assert profile
model_with_profile = GenericFakeChatModel(
messages=iter([]), profile={"max_input_tokens": 100}
)
assert model_with_profile.profile == {"max_input_tokens": 100}


class MockResponse: