Skip to content

Commit 831acd6

Browse files
committed
fmt
1 parent f6af5a1 commit 831acd6

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from langchain_core._api import beta
2323
from langchain_core.callbacks import CallbackManagerForLLMRun
2424
from langchain_core.language_models import BaseChatModel, LanguageModelInput
25+
from langchain_core.language_models.chat_models import LangSmithParams
2526
from langchain_core.messages import (
2627
AIMessage,
2728
BaseMessage,
@@ -476,6 +477,23 @@ def _converse_params(
476477
}
477478
)
478479

480+
def _get_ls_params(
481+
self, stop: Optional[List[str]] = None, **kwargs: Any
482+
) -> LangSmithParams:
483+
"""Get standard params for tracing."""
484+
params = self._get_invocation_params(stop=stop, **kwargs)
485+
ls_params = LangSmithParams(
486+
ls_provider="amazon_bedrock",
487+
ls_model_name=self.model_id,
488+
ls_model_type="chat",
489+
ls_temperature=params.get("temperature", self.temperature),
490+
)
491+
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
492+
ls_params["ls_max_tokens"] = ls_max_tokens
493+
if ls_stop := stop or params.get("stop", None):
494+
ls_params["ls_stop"] = ls_stop
495+
return ls_params
496+
479497
@property
480498
def _llm_type(self) -> str:
481499
"""Return type of chat model."""

libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,40 @@
11
"""Test chat model integration."""
22

3-
from typing import cast
3+
from typing import Type, cast
44

5+
import pytest
6+
from langchain_core.language_models import BaseChatModel
57
from langchain_core.pydantic_v1 import BaseModel, Field
68
from langchain_core.runnables import RunnableBinding
9+
from langchain_standard_tests.unit_tests import ChatModelUnitTests
710

811
from langchain_aws import ChatBedrockConverse
912

1013

14+
class TestBedrockStandard(ChatModelUnitTests):
15+
@pytest.fixture
16+
def chat_model_class(self) -> Type[BaseChatModel]:
17+
return ChatBedrockConverse
18+
19+
@pytest.fixture
20+
def chat_model_params(self) -> dict:
21+
return {
22+
"model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
23+
}
24+
25+
@pytest.mark.xfail()
26+
def test_chat_model_init_api_key(
27+
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
28+
) -> None:
29+
super().test_chat_model_init_api_key(chat_model_class, chat_model_params)
30+
31+
@pytest.mark.xfail()
32+
def test_chat_model_init_streaming(
33+
self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
34+
) -> None:
35+
super().test_chat_model_init_streaming(chat_model_class, chat_model_params)
36+
37+
1138
class GetWeather(BaseModel):
1239
"""Get the current weather in a given location"""
1340

0 commit comments

Comments
 (0)