Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import cached_property
import json
import logging
from pathlib import Path
import re
from operator import itemgetter
import uuid
Expand All @@ -35,6 +36,10 @@
generate_from_stream,
agenerate_from_stream,
)
from langchain_core.language_models.profile import ModelProfile, ModelProfileRegistry
from langchain_core.language_models.profile._loader_utils import (
load_profiles_from_data_dir,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -178,6 +183,17 @@
]


_MODEL_PROFILES = cast(
"ModelProfileRegistry",
load_profiles_from_data_dir(Path(__file__).parent / "data", "google-vertex"),
)


def _get_default_model_profile(model_name: str) -> ModelProfile:
default = _MODEL_PROFILES.get(model_name) or {}
return default.copy()


_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY = (
"__gemini_function_call_thought_signatures__"
)
Expand Down Expand Up @@ -1867,6 +1883,14 @@ def validate_environment(self) -> Self:

return self

@model_validator(mode="after")
def _set_model_profile(self) -> Self:
"""Set model profile if not overridden."""
if self.profile is None:
model_id = re.sub(r"-\d{3}$", "", self.model_name.replace("models/", ""))
self.profile = _get_default_model_profile(model_id)
return self

def _prepare_params(
self,
stop: list[str] | None = None,
Expand Down
Loading