3 changes: 3 additions & 0 deletions src/fmcore/aws/constants/aws_constants.py
@@ -13,6 +13,9 @@
CREDENTIALS: str = "Credentials"
ERROR: str = "Error"
EXPIRATION: str = "Expiration"
EXTENDED_THINKING_MODELS: list = [
    'us.anthropic.claude-3-7-sonnet-20250219-v1:0'
]
REGION: str = "region"
REGION_NAME: str = "region_name"
ROLE_ARN: str = "role_arn"
41 changes: 41 additions & 0 deletions src/fmcore/llm/bedrock_llm.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel
from langchain_core.messages import BaseMessage, BaseMessageChunk

from fmcore.aws.constants.aws_constants import EXTENDED_THINKING_MODELS
from fmcore.aws.factory.bedrock_factory import BedrockFactory
from fmcore.llm.base_llm import BaseLLM
from fmcore.llm.types.llm_types import LLMConfig
@@ -30,6 +31,45 @@ class BedrockLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], Base
    client: ChatBedrockConverse
    rate_limiter: AsyncLimiter

    @classmethod
    def _validate_bedrock_params(cls, llm_config: LLMConfig) -> None:
        """
        Validates and adjusts thinking-related parameters for Bedrock models.

        Args:
            llm_config: The LLM configuration to validate

        Raises:
            ValueError: If token budget constraints are violated
        """
        model_params = llm_config.model_params

        if llm_config.model_id in EXTENDED_THINKING_MODELS:
            thinking_params = {}

            if (model_params.additional_model_request_fields is not None and
                    "thinking" in model_params.additional_model_request_fields):
                thinking_params = model_params.additional_model_request_fields["thinking"]

            if "type" in thinking_params:
                if thinking_params["type"] == "enabled":
                    # Fix temperature and unset top p for thinking mode
                    model_params.temperature = 1.0
Review comment: This transformation can be skipped - it'd silently update the values for compliance and can lead to unexpected results. If the input config is not supported, the user should be the one to fix it.

                    model_params.top_p = None

                    # Validate token budgets
                    budget_tokens = thinking_params.get("budget_tokens", 0)
                    if budget_tokens < 1024:
                        raise ValueError("Budget tokens must be at least 1024")
                    if budget_tokens > model_params.max_tokens:
                        raise ValueError("Budget tokens must not exceed max tokens")
                elif "budget_tokens" in thinking_params:
                    # Remove budget tokens if thinking is disabled
                    model_params.additional_model_request_fields["thinking"].pop("budget_tokens")
        else:
Review comment: Should we avoid this else block? If a new Bedrock model is launched that supports extended thinking, then you'd be removing the thinking params passed in the llm-config until EXTENDED_THINKING_MODELS is updated in this code.

Not having it lets the values pass through to converse-client creation; if it's an error on the application side to send these params when they were not supposed to, client creation will just fail - as expected.
            # Clear additional fields for non-thinking models
            model_params.additional_model_request_fields = None

    @classmethod
    def _get_instance(cls, *, llm_config: LLMConfig) -> "BedrockLLM":
        """
@@ -41,6 +81,7 @@ def _get_instance(cls, *, llm_config: LLMConfig) -> "BedrockLLM":
        Args:
            llm_config (SingleLLMConfig): Contains model_id, model_params, and provider_params.
        """
        cls._validate_bedrock_params(llm_config)
        converse_client = BedrockFactory.create_converse_client(llm_config=llm_config)
        rate_limiter = RateLimiterUtils.create_async_rate_limiter(
            rate_limit_config=llm_config.provider_params.rate_limit
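For context on the first review comment above, here is a rough sketch of the stricter alternative it describes: validate the thinking-related values and raise instead of silently rewriting them. The helper name and its standalone signature are hypothetical and not part of this PR; the constraints simply mirror the ones the PR enforces (temperature fixed at 1.0, top_p unset, budget_tokens between 1024 and max_tokens).

# Hypothetical sketch (not part of this PR): reject unsupported values instead
# of silently coercing them, as the review comment suggests.
from typing import Any, Dict, Optional


def check_thinking_params(
    temperature: Optional[float],
    top_p: Optional[float],
    max_tokens: int,
    thinking: Dict[str, Any],
) -> None:
    """Raise ValueError when a config is incompatible with extended thinking."""
    if thinking.get("type") != "enabled":
        return
    if temperature is not None and temperature != 1.0:
        raise ValueError(f"Extended thinking requires temperature=1.0, got {temperature}")
    if top_p is not None:
        raise ValueError("top_p must be unset when extended thinking is enabled")
    budget_tokens = thinking.get("budget_tokens", 0)
    if budget_tokens < 1024:
        raise ValueError("budget_tokens must be at least 1024")
    if budget_tokens > max_tokens:
        raise ValueError("budget_tokens must not exceed max_tokens")


# Usage: an incompatible config is surfaced to the caller instead of being rewritten.
try:
    check_thinking_params(
        temperature=0.5,
        top_p=0.5,
        max_tokens=4096,
        thinking={"type": "enabled", "budget_tokens": 2048},
    )
except ValueError as err:
    print(f"Rejected config: {err}")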
4 changes: 3 additions & 1 deletion src/fmcore/llm/types/llm_types.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, List, Dict
+from typing import Union, Optional, List, Dict, Any

from pydantic import model_validator, SerializeAsAny

@@ -18,11 +18,13 @@ class ModelParams(MutableTyped):
        max_tokens (Optional[int]): Specifies the maximum number of tokens to generate in the response.
        top_p (Optional[float]): Enables nucleus sampling, where the model considers
            only the tokens comprising the top `p` cumulative probability mass.
        additional_model_request_fields (Optional[Dict[str, Any]]): Additional inference parameters that the model supports.
    """

    temperature: Optional[float] = 0.5
    max_tokens: Optional[int] = 1024
    top_p: Optional[float] = 0.5
    additional_model_request_fields: Optional[Dict[str, Any]] = None


class LLMConfig(MutableTyped):
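As an illustration of the new field, here is a minimal example of what a thinking-enabled configuration might look like. It assumes ModelParams accepts its fields as keyword arguments in the usual pydantic style; the specific values are placeholders, not part of this PR.

# Illustrative only: the extended-thinking payload rides in
# additional_model_request_fields, assuming ModelParams accepts keyword
# arguments like a pydantic model.
from fmcore.llm.types.llm_types import ModelParams

params = ModelParams(
    temperature=1.0,   # extended thinking runs with temperature fixed at 1.0
    max_tokens=8192,   # must be larger than the thinking budget
    top_p=None,        # top_p is left unset when thinking is enabled
    additional_model_request_fields={
        "thinking": {"type": "enabled", "budget_tokens": 2048},
    },
)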