Commit 4d8eec9

refactor: enhance model validation and provider inference in LLM class (#3976)
* refactor: enhance model validation and provider inference in LLM class
  - Updated the model validation logic to support pattern matching for new models and "latest" versions, improving flexibility across providers.
  - Refactored the `_validate_model_in_constants` method to first check hardcoded constants and then fall back to pattern matching.
  - Introduced `_matches_provider_pattern` to streamline provider-specific model checks.
  - Enhanced the `_infer_provider_from_model` method to use pattern matching for better provider inference.
  This refactor improves the extensibility of the LLM class, allowing it to accommodate new models without requiring constant updates to the hardcoded lists.

* feat: add new Anthropic model versions to constants
  - Added "claude-opus-4-5-20251101" and "claude-opus-4-5" to the AnthropicModels and ANTHROPIC_MODELS lists.
  - Added "anthropic.claude-opus-4-5-20251101-v1:0" to BedrockModels and BEDROCK_MODELS for compatibility with the latest model offerings.
  - Updated test cases to ensure proper environment variable handling for model validation, improving robustness in testing scenarios.

* don't infer this way - dropped
1 parent 2025a26 commit 4d8eec9
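The constants-first, pattern-fallback validation described in the commit message can be summed up with a minimal standalone sketch. The names below (`validate_model`, `matches_provider_pattern`, `KNOWN_MODELS`, `PROVIDER_PREFIXES`) are hypothetical stand-ins rather than the actual crewai symbols; only the prefix lists echo the ones introduced in this commit's `_matches_provider_pattern`.

```python
# Minimal sketch (assumed names, not the crewai implementation) of the
# constants-first lookup with a pattern-matching fallback added in this commit.

# Hypothetical subsets of the hardcoded model constants.
KNOWN_MODELS: dict[str, set[str]] = {
    "openai": {"gpt-4o", "gpt-4o-mini"},
    "anthropic": {"claude-opus-4-5", "claude-3-7-sonnet-latest"},
}

# Provider naming prefixes, mirroring those used in _matches_provider_pattern.
PROVIDER_PREFIXES: dict[str, tuple[str, ...]] = {
    "openai": ("gpt-", "o1", "o3", "o4", "whisper-"),
    "anthropic": ("claude-", "anthropic."),
}


def matches_provider_pattern(model: str, provider: str) -> bool:
    """Return True when the model name follows the provider's naming convention."""
    return model.lower().startswith(PROVIDER_PREFIXES.get(provider, ()))


def validate_model(model: str, provider: str) -> bool:
    """Check the hardcoded constants first, then fall back to pattern matching."""
    if model in KNOWN_MODELS.get(provider, set()):
        return True
    return matches_provider_pattern(model, provider)


# New or "latest" models that follow the naming convention still validate,
# while arbitrary names are rejected:
assert validate_model("gpt-future-6", "openai") is True
assert validate_model("claude-future-5", "anthropic") is True
assert validate_model("unknown-model", "openai") is False
```

This mirrors the behavior exercised in the updated test_validate_model_in_constants test further down, where "gpt-future-6" and "claude-future-5" now validate while "unknown-model" does not.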

3 files changed: 128 additions & 27 deletions


lib/crewai/src/crewai/llm.py

Lines changed: 73 additions & 17 deletions
@@ -407,45 +407,99 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         return instance

     @classmethod
-    def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
-        """Validate if a model name exists in the provider's constants.
+    def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
+        """Check if a model name matches provider-specific patterns.
+
+        This allows supporting models that aren't in the hardcoded constants list,
+        including "latest" versions and new models that follow provider naming conventions.

         Args:
-            model: The model name to validate
+            model: The model name to check
             provider: The provider to check against (canonical name)

         Returns:
-            True if the model exists in the provider's constants, False otherwise
+            True if the model matches the provider's naming pattern, False otherwise
         """
+        model_lower = model.lower()
+
         if provider == "openai":
-            return model in OPENAI_MODELS
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
+            )

         if provider == "anthropic" or provider == "claude":
-            return model in ANTHROPIC_MODELS
+            return any(
+                model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
+            )

-        if provider == "gemini":
-            return model in GEMINI_MODELS
+        if provider == "gemini" or provider == "google":
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gemini-", "gemma-", "learnlm-"]
+            )

         if provider == "bedrock":
-            return model in BEDROCK_MODELS
+            return "." in model_lower
+
+        if provider == "azure":
+            return any(
+                model_lower.startswith(prefix)
+                for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
+            )
+
+        return False
+
+    @classmethod
+    def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
+        """Validate if a model name exists in the provider's constants or matches provider patterns.
+
+        This method first checks the hardcoded constants list for known models.
+        If not found, it falls back to pattern matching to support new models,
+        "latest" versions, and models that follow provider naming conventions.
+
+        Args:
+            model: The model name to validate
+            provider: The provider to check against (canonical name)
+
+        Returns:
+            True if the model exists in constants or matches provider patterns, False otherwise
+        """
+        if provider == "openai" and model in OPENAI_MODELS:
+            return True
+
+        if (
+            provider == "anthropic" or provider == "claude"
+        ) and model in ANTHROPIC_MODELS:
+            return True
+
+        if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
+            return True
+
+        if provider == "bedrock" and model in BEDROCK_MODELS:
+            return True

         if provider == "azure":
             # azure does not provide a list of available models, determine a better way to handle this
             return True

-        return False
+        # Fallback to pattern matching for models not in constants
+        return cls._matches_provider_pattern(model, provider)

     @classmethod
     def _infer_provider_from_model(cls, model: str) -> str:
         """Infer the provider from the model name.

+        This method first checks the hardcoded constants list for known models.
+        If not found, it uses pattern matching to infer the provider from model name patterns.
+        This allows supporting new models and "latest" versions without hardcoding.
+
         Args:
             model: The model name without provider prefix

         Returns:
             The inferred provider name, defaults to "openai"
         """
-
         if model in OPENAI_MODELS:
             return "openai"

@@ -1699,12 +1753,14 @@ def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
             max_tokens=self.max_tokens,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            logit_bias=copy.deepcopy(self.logit_bias, memo)
-            if self.logit_bias
-            else None,
-            response_format=copy.deepcopy(self.response_format, memo)
-            if self.response_format
-            else None,
+            logit_bias=(
+                copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
+            ),
+            response_format=(
+                copy.deepcopy(self.response_format, memo)
+                if self.response_format
+                else None
+            ),
             seed=self.seed,
             logprobs=self.logprobs,
             top_logprobs=self.top_logprobs,

lib/crewai/src/crewai/llms/constants.py

Lines changed: 6 additions & 0 deletions
@@ -182,6 +182,8 @@


 AnthropicModels: TypeAlias = Literal[
+    "claude-opus-4-5-20251101",
+    "claude-opus-4-5",
     "claude-3-7-sonnet-latest",
     "claude-3-7-sonnet-20250219",
     "claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@
     "claude-3-haiku-20240307",
 ]
 ANTHROPIC_MODELS: list[AnthropicModels] = [
+    "claude-opus-4-5-20251101",
+    "claude-opus-4-5",
     "claude-3-7-sonnet-latest",
     "claude-3-7-sonnet-20250219",
     "claude-3-5-haiku-latest",
@@ -452,6 +456,7 @@
     "anthropic.claude-3-sonnet-20240229-v1:0:28k",
     "anthropic.claude-haiku-4-5-20251001-v1:0",
     "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-opus-4-5-20251101-v1:0",
     "anthropic.claude-opus-4-1-20250805-v1:0",
     "anthropic.claude-opus-4-20250514-v1:0",
     "anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +529,7 @@
     "anthropic.claude-3-sonnet-20240229-v1:0:28k",
     "anthropic.claude-haiku-4-5-20251001-v1:0",
     "anthropic.claude-instant-v1:2:100k",
+    "anthropic.claude-opus-4-5-20251101-v1:0",
     "anthropic.claude-opus-4-1-20250805-v1:0",
     "anthropic.claude-opus-4-20250514-v1:0",
     "anthropic.claude-sonnet-4-20250514-v1:0",

lib/crewai/tests/test_llm.py

Lines changed: 49 additions & 10 deletions
@@ -243,7 +243,11 @@ class DummyResponse(BaseModel):

     # Patch supports_response_schema to simulate an unsupported model.
     with patch("crewai.llm.supports_response_schema", return_value=False):
-        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
+        llm = LLM(
+            model="gemini/gemini-1.5-pro",
+            response_format=DummyResponse,
+            is_litellm=True,
+        )
         with pytest.raises(ValueError) as excinfo:
             llm._validate_call_params()
         assert "does not support response_format" in str(excinfo.value)
@@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):

     assert formatted == original_messages

+
 def test_native_provider_raises_error_when_supported_but_fails():
     """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
     with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
         with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
             # Mock that provider exists but throws an error when instantiated
             mock_provider = MagicMock()
-            mock_provider.side_effect = ValueError("Native provider initialization failed")
+            mock_provider.side_effect = ValueError(
+                "Native provider initialization failed"
+            )
             mock_get_native.return_value = mock_provider

             with pytest.raises(ImportError) as excinfo:
@@ -751,23 +758,38 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():


 def test_prefixed_models_with_invalid_constants_use_litellm():
-    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
+    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
     # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
     llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
     assert llm.is_litellm is True
     assert llm.model == "openai/gemini-2.5-flash"

-    # Test openai/ prefix with unknown future model → LiteLLM
-    llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
+    # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
+    llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
     assert llm2.is_litellm is True
-    assert llm2.model == "openai/gpt-future-6"
+    assert llm2.model == "openai/custom-finetune-model"

     # Test anthropic/ prefix with non-Anthropic model → LiteLLM
     llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
     assert llm3.is_litellm is True
     assert llm3.model == "anthropic/gpt-4o"


+def test_prefixed_models_with_valid_patterns_use_native_sdk():
+    """Test that models matching provider patterns use native SDK even if not in constants."""
+    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+        llm = LLM(model="openai/gpt-future-6", is_litellm=False)
+        assert llm.is_litellm is False
+        assert llm.provider == "openai"
+        assert llm.model == "gpt-future-6"
+
+    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
+        llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
+        assert llm2.is_litellm is False
+        assert llm2.provider == "anthropic"
+        assert llm2.model == "claude-future-5"
+
+
 def test_prefixed_models_with_non_native_providers_use_litellm():
     """Test that models with non-native provider prefixes always use LiteLLM."""
     # Test groq/ prefix (not a native provider) → LiteLLM
@@ -821,19 +843,36 @@ def test_validate_model_in_constants():
     """Test the _validate_model_in_constants method."""
     # OpenAI models
     assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
-    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
+    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
+    assert LLM._validate_model_in_constants("o1-latest", "openai") is True
+    assert LLM._validate_model_in_constants("unknown-model", "openai") is False

     # Anthropic models
     assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
-    assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
+    assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
+    assert (
+        LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
+    )
+    assert LLM._validate_model_in_constants("unknown-model", "claude") is False

     # Gemini models
     assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
-    assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
+    assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
+    assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
+    assert LLM._validate_model_in_constants("unknown-model", "gemini") is False

     # Azure models
     assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
     assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True

     # Bedrock models
-    assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
+    assert (
+        LLM._validate_model_in_constants(
+            "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
+        )
+        is True
+    )
+    assert (
+        LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
+        is True
+    )
