Skip to content

Commit 8dca1b4

Browse files
fix: remove api_url reference from Watson embedding callable
1 parent fa8b159 commit 8dca1b4

File tree

21 files changed: 268 additions and 85 deletions

21 files changed: 268 additions and 85 deletions

src/crewai/rag/core/base_embeddings_provider.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Generic, TypeVar
44

55
from pydantic import Field
6-
from pydantic_settings import BaseSettings
6+
from pydantic_settings import BaseSettings, SettingsConfigDict
77

88
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
99

@@ -17,6 +17,7 @@ class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
1717
embedding functions from various providers.
1818
"""
1919

20+
model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
2021
embedding_callable: type[T] = Field(
2122
..., description="The embedding function class to use"
2223
)

src/crewai/rag/embeddings/providers/aws/bedrock.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
5151
model_name: str = Field(
5252
default="amazon.titan-embed-text-v1",
5353
description="Model name to use for embeddings",
54+
validation_alias="BEDROCK_MODEL_NAME",
5455
)
5556
session: Session = Field(
5657
default_factory=create_aws_session, description="AWS session object"

src/crewai/rag/embeddings/providers/cohere/cohere_provider.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
1414
embedding_callable: type[CohereEmbeddingFunction] = Field(
1515
default=CohereEmbeddingFunction, description="Cohere embedding function class"
1616
)
17-
api_key: str = Field(description="Cohere API key", alias="COHERE_API_KEY")
17+
api_key: str = Field(
18+
description="Cohere API key", validation_alias="COHERE_API_KEY"
19+
)
1820
model_name: str = Field(
19-
default="large", description="Model name to use for embeddings"
21+
default="large",
22+
description="Model name to use for embeddings",
23+
validation_alias="COHERE_MODEL_NAME",
2024
)

src/crewai/rag/embeddings/providers/google/generative_ai.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,15 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
1616
description="Google Generative AI embedding function class",
1717
)
1818
model_name: str = Field(
19-
default="models/embedding-001", description="Model name to use for embeddings"
19+
default="models/embedding-001",
20+
description="Model name to use for embeddings",
21+
validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
22+
)
23+
api_key: str = Field(
24+
description="Google API key", validation_alias="GOOGLE_API_KEY"
2025
)
21-
api_key: str = Field(description="Google API key", alias="GOOGLE_API_KEY")
2226
task_type: str = Field(
23-
default="RETRIEVAL_DOCUMENT", description="Task type for embeddings"
27+
default="RETRIEVAL_DOCUMENT",
28+
description="Task type for embeddings",
29+
validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
2430
)

src/crewai/rag/embeddings/providers/google/vertex.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
1616
description="Vertex AI embedding function class",
1717
)
1818
model_name: str = Field(
19-
default="textembedding-gecko", description="Model name to use for embeddings"
19+
default="textembedding-gecko",
20+
description="Model name to use for embeddings",
21+
validation_alias="GOOGLE_VERTEX_MODEL_NAME",
22+
)
23+
api_key: str = Field(
24+
description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
2025
)
21-
api_key: str = Field(description="Google API key", alias="GOOGLE_CLOUD_API_KEY")
2226
project_id: str = Field(
2327
default="cloud-large-language-models",
2428
description="GCP project ID",
25-
alias="GOOGLE_CLOUD_PROJECT",
29+
validation_alias="GOOGLE_CLOUD_PROJECT",
2630
)
2731
region: str = Field(
28-
default="us-central1", description="GCP region", alias="GOOGLE_CLOUD_REGION"
32+
default="us-central1",
33+
description="GCP region",
34+
validation_alias="GOOGLE_CLOUD_REGION",
2935
)

src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
1515
default=HuggingFaceEmbeddingServer,
1616
description="HuggingFace embedding function class",
1717
)
18-
url: str = Field(description="HuggingFace API URL")
18+
url: str = Field(
19+
description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
20+
)

src/crewai/rag/embeddings/providers/ibm/embedding_callable.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ def __call__(self, input: Documents) -> Embeddings:
7272
embeddings_config["credentials"] = self._config["credentials"]
7373
else:
7474
cred_config: dict = {}
75-
if "api_url" in self._config and self._config["api_url"] is not None:
76-
cred_config["url"] = self._config["api_url"]
7775
if "url" in self._config and self._config["url"] is not None:
7876
cred_config["url"] = self._config["url"]
7977
if "api_key" in self._config and self._config["api_key"] is not None:

src/crewai/rag/embeddings/providers/ibm/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
from typing import Annotated, Any, Literal, TypedDict
44

5-
from typing_extensions import Required
6-
75

86
class WatsonProviderConfig(TypedDict, total=False):
97
"""Configuration for Watson provider."""
108

11-
model_id: Required[str]
9+
model_id: str
10+
url: str
1211
params: dict[str, str | dict[str, str]]
1312
credentials: Any
1413
project_id: str
@@ -21,8 +20,6 @@ class WatsonProviderConfig(TypedDict, total=False):
2120
max_retries: int
2221
delay_time: float
2322
retry_status_codes: list[int]
24-
api_url: str
25-
url: str
2623
api_key: str
2724
name: str
2825
iam_serviceid_crn: str

src/crewai/rag/embeddings/providers/ibm/watson.py

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
APIClient,
55
Credentials,
66
)
7-
from pydantic import Field
7+
from pydantic import Field, model_validator
8+
from typing_extensions import Self
89

910
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
1011
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
@@ -21,44 +22,105 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
2122
embedding_callable: type[WatsonEmbeddingFunction] = Field(
2223
default=WatsonEmbeddingFunction, description="Watson embedding function class"
2324
)
24-
model_id: str = Field(description="Watson model ID")
25+
model_id: str = Field(
26+
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
27+
)
2528
params: dict[str, str | dict[str, str]] | None = Field(
2629
default=None, description="Additional parameters"
2730
)
2831
credentials: Credentials | None = Field(
2932
default=None, description="Watson credentials"
3033
)
31-
project_id: str | None = Field(default=None, description="Watson project ID")
32-
space_id: str | None = Field(default=None, description="Watson space ID")
34+
project_id: str | None = Field(
35+
default=None,
36+
description="Watson project ID",
37+
validation_alias="WATSON_PROJECT_ID",
38+
)
39+
space_id: str | None = Field(
40+
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
41+
)
3342
api_client: APIClient | None = Field(default=None, description="Watson API client")
34-
verify: bool | str | None = Field(default=None, description="SSL verification")
43+
verify: bool | str | None = Field(
44+
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
45+
)
3546
persistent_connection: bool = Field(
36-
default=True, description="Use persistent connection"
47+
default=True,
48+
description="Use persistent connection",
49+
validation_alias="WATSON_PERSISTENT_CONNECTION",
50+
)
51+
batch_size: int = Field(
52+
default=100,
53+
description="Batch size for processing",
54+
validation_alias="WATSON_BATCH_SIZE",
55+
)
56+
concurrency_limit: int = Field(
57+
default=10,
58+
description="Concurrency limit",
59+
validation_alias="WATSON_CONCURRENCY_LIMIT",
60+
)
61+
max_retries: int | None = Field(
62+
default=None,
63+
description="Maximum retries",
64+
validation_alias="WATSON_MAX_RETRIES",
3765
)
38-
batch_size: int = Field(default=100, description="Batch size for processing")
39-
concurrency_limit: int = Field(default=10, description="Concurrency limit")
40-
max_retries: int | None = Field(default=None, description="Maximum retries")
4166
delay_time: float | None = Field(
42-
default=None, description="Delay time between retries"
67+
default=None,
68+
description="Delay time between retries",
69+
validation_alias="WATSON_DELAY_TIME",
4370
)
4471
retry_status_codes: list[int] | None = Field(
4572
default=None, description="HTTP status codes to retry on"
4673
)
47-
url: str | None = Field(default=None, description="Watson API URL")
48-
api_key: str | None = Field(default=None, description="Watson API key")
49-
name: str | None = Field(default=None, description="Service name")
74+
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
75+
api_key: str = Field(
76+
description="Watson API key", validation_alias="WATSON_API_KEY"
77+
)
78+
name: str | None = Field(
79+
default=None, description="Service name", validation_alias="WATSON_NAME"
80+
)
5081
iam_serviceid_crn: str | None = Field(
51-
default=None, description="IAM service ID CRN"
82+
default=None,
83+
description="IAM service ID CRN",
84+
validation_alias="WATSON_IAM_SERVICEID_CRN",
5285
)
5386
trusted_profile_id: str | None = Field(
54-
default=None, description="Trusted profile ID"
55-
)
56-
token: str | None = Field(default=None, description="Bearer token")
57-
projects_token: str | None = Field(default=None, description="Projects token")
58-
username: str | None = Field(default=None, description="Username")
59-
password: str | None = Field(default=None, description="Password")
60-
instance_id: str | None = Field(default=None, description="Service instance ID")
61-
version: str | None = Field(default=None, description="API version")
62-
bedrock_url: str | None = Field(default=None, description="Bedrock URL")
63-
platform_url: str | None = Field(default=None, description="Platform URL")
87+
default=None,
88+
description="Trusted profile ID",
89+
validation_alias="WATSON_TRUSTED_PROFILE_ID",
90+
)
91+
token: str | None = Field(
92+
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
93+
)
94+
projects_token: str | None = Field(
95+
default=None,
96+
description="Projects token",
97+
validation_alias="WATSON_PROJECTS_TOKEN",
98+
)
99+
username: str | None = Field(
100+
default=None, description="Username", validation_alias="WATSON_USERNAME"
101+
)
102+
password: str | None = Field(
103+
default=None, description="Password", validation_alias="WATSON_PASSWORD"
104+
)
105+
instance_id: str | None = Field(
106+
default=None,
107+
description="Service instance ID",
108+
validation_alias="WATSON_INSTANCE_ID",
109+
)
110+
version: str | None = Field(
111+
default=None, description="API version", validation_alias="WATSON_VERSION"
112+
)
113+
bedrock_url: str | None = Field(
114+
default=None, description="Bedrock URL", validation_alias="WATSON_BEDROCK_URL"
115+
)
116+
platform_url: str | None = Field(
117+
default=None, description="Platform URL", validation_alias="WATSON_PLATFORM_URL"
118+
)
64119
proxies: dict | None = Field(default=None, description="Proxy configuration")
120+
121+
@model_validator(mode="after")
122+
def validate_space_or_project(self) -> Self:
123+
"""Validate that either space_id or project_id is provided."""
124+
if not self.space_id and not self.project_id:
125+
raise ValueError("One of 'space_id' or 'project_id' must be provided")
126+
return self

src/crewai/rag/embeddings/providers/instructor/instructor_provider.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
1616
description="Instructor embedding function class",
1717
)
1818
model_name: str = Field(
19-
default="hkunlp/instructor-base", description="Model name to use"
19+
default="hkunlp/instructor-base",
20+
description="Model name to use",
21+
validation_alias="INSTRUCTOR_MODEL_NAME",
2022
)
2123
device: str = Field(
22-
default="cpu", description="Device to run model on (cpu or cuda)"
24+
default="cpu",
25+
description="Device to run model on (cpu or cuda)",
26+
validation_alias="INSTRUCTOR_DEVICE",
2327
)
2428
instruction: str | None = Field(
25-
default=None, description="Instruction for embeddings"
29+
default=None,
30+
description="Instruction for embeddings",
31+
validation_alias="INSTRUCTOR_INSTRUCTION",
2632
)

0 commit comments

Comments
 (0)