diff --git a/src/fmcore/aws/constants/aws_constants.py b/src/fmcore/aws/constants/aws_constants.py
index b187c92..4c003dc 100644
--- a/src/fmcore/aws/constants/aws_constants.py
+++ b/src/fmcore/aws/constants/aws_constants.py
@@ -13,6 +13,9 @@
 CREDENTIALS: str = "Credentials"
 ERROR: str = "Error"
 EXPIRATION: str = "Expiration"
+EXTENDED_THINKING_MODELS: list = [
+    'us.anthropic.claude-3-7-sonnet-20250219-v1:0'
+]
 REGION: str = "region"
 REGION_NAME: str = "region_name"
 ROLE_ARN: str = "role_arn"
diff --git a/src/fmcore/llm/bedrock_llm.py b/src/fmcore/llm/bedrock_llm.py
index 5ca643b..ca59843 100644
--- a/src/fmcore/llm/bedrock_llm.py
+++ b/src/fmcore/llm/bedrock_llm.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel
 
 from langchain_core.messages import BaseMessage, BaseMessageChunk
+from fmcore.aws.constants.aws_constants import EXTENDED_THINKING_MODELS
 from fmcore.aws.factory.bedrock_factory import BedrockFactory
 from fmcore.llm.base_llm import BaseLLM
 from fmcore.llm.types.llm_types import LLMConfig
@@ -30,6 +31,45 @@ class BedrockLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], Base
     client: ChatBedrockConverse
     rate_limiter: AsyncLimiter
 
+    @classmethod
+    def _validate_bedrock_params(cls, llm_config: LLMConfig) -> None:
+        """
+        Validates and adjusts thinking-related parameters for Bedrock models.
+
+        Args:
+            llm_config: The LLM configuration to validate
+
+        Raises:
+            ValueError: If token budget constraints are violated
+        """
+        model_params = llm_config.model_params
+
+        if llm_config.model_id in EXTENDED_THINKING_MODELS:
+            thinking_params = {}
+
+            if (model_params.additional_model_request_fields is not None and
+                    "thinking" in model_params.additional_model_request_fields):
+                thinking_params = model_params.additional_model_request_fields["thinking"]
+
+            if "type" in thinking_params:
+                if thinking_params["type"] == "enabled":
+                    # Fix temperature and unset top p for thinking mode
+                    model_params.temperature = 1.0
+                    model_params.top_p = None
+
+                    # Validate token budgets: minimum is 1024, and max_tokens must strictly exceed the budget
+                    budget_tokens = thinking_params.get("budget_tokens", 0)
+                    if budget_tokens < 1024:
+                        raise ValueError("Budget tokens must be at least 1024")
+                    if budget_tokens >= model_params.max_tokens:
+                        raise ValueError("Max tokens must be greater than budget tokens")
+                elif "budget_tokens" in thinking_params:
+                    # Remove budget tokens if thinking is disabled
+                    model_params.additional_model_request_fields["thinking"].pop("budget_tokens")
+        else:
+            # Clear additional fields for non-thinking models
+            model_params.additional_model_request_fields = None
+
     @classmethod
     def _get_instance(cls, *, llm_config: LLMConfig) -> "BedrockLLM":
         """
@@ -41,6 +81,7 @@
         Args:
             llm_config (SingleLLMConfig): Contains model_id, model_params, and provider_params.
         """
+        cls._validate_bedrock_params(llm_config)
         converse_client = BedrockFactory.create_converse_client(llm_config=llm_config)
         rate_limiter = RateLimiterUtils.create_async_rate_limiter(
             rate_limit_config=llm_config.provider_params.rate_limit
diff --git a/src/fmcore/llm/types/llm_types.py b/src/fmcore/llm/types/llm_types.py
index ede4d20..0914f6a 100644
--- a/src/fmcore/llm/types/llm_types.py
+++ b/src/fmcore/llm/types/llm_types.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, List, Dict
+from typing import Union, Optional, List, Dict, Any
 
 from pydantic import model_validator, SerializeAsAny
 
@@ -18,11 +18,13 @@ class ModelParams(MutableTyped):
         max_tokens (Optional[int]): Specifies the maximum number of tokens to generate in the response.
         top_p (Optional[float]): Enables nucleus sampling, where the model considers only the tokens
             comprising the top `p` cumulative probability mass.
+        additional_model_request_fields: Additional inference parameters that the model supports
     """
 
     temperature: Optional[float] = 0.5
     max_tokens: Optional[int] = 1024
     top_p: Optional[float] = 0.5
+    additional_model_request_fields: Optional[Dict[str, Any]] = None
 
 
 class LLMConfig(MutableTyped):