Skip to content

Commit 364fc9b

Browse files
authored
fix: change bedrock embed_model_name to embedder_model_name (#385)
1 parent 4edcd1b commit 364fc9b

File tree

6 files changed

+20
-16
lines changed

6 files changed

+20
-16
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
## 0.5.3-dev2
1+
## 0.5.3
22

33
### Enhancements
44

@@ -7,6 +7,8 @@
77

88
### Fixes
99

10+
* **Fix bedrock embedder: rename embed_model_name to embedder_model_name**
11+
1012
## 0.5.2
1113

1214
### Enhancements

test/unit/v2/embedders/test_bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def generate_embedder_config_params() -> dict:
1515
"region_name": fake.city(),
1616
}
1717
if random.random() < 0.5:
18-
params["embed_model_name"] = fake.word()
18+
params["embedder_model_name"] = fake.word()
1919
return params
2020

2121

test/unit/v2/embedders/test_huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
def generate_embedder_config_params() -> dict:
1717
params = {}
1818
if random.random() < 0.5:
19-
params["embed_model_name"] = fake.word() if random.random() < 0.5 else None
19+
params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
2020
params["embedder_model_kwargs"] = (
2121
generate_random_dictionary(key_type=str, value_type=Any)
2222
if random.random() < 0.5

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "0.5.3-dev2" # pragma: no cover
1+
__version__ = "0.5.3" # pragma: no cover

unstructured_ingest/embed/bedrock.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
5757
aws_access_key_id: SecretStr
5858
aws_secret_access_key: SecretStr
5959
region_name: str = "us-west-2"
60-
embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
60+
embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
6161

6262
def wrap_error(self, e: Exception) -> Exception:
6363
if is_internal_error(e=e):
@@ -130,15 +130,15 @@ def wrap_error(self, e: Exception) -> Exception:
130130

131131
def embed_query(self, query: str) -> list[float]:
132132
"""Call out to Bedrock embedding endpoint."""
133-
provider = self.config.embed_model_name.split(".")[0]
133+
provider = self.config.embedder_model_name.split(".")[0]
134134
body = conform_query(query=query, provider=provider)
135135

136136
bedrock_client = self.config.get_client()
137137
# invoke bedrock API
138138
try:
139139
response = bedrock_client.invoke_model(
140140
body=json.dumps(body),
141-
modelId=self.config.embed_model_name,
141+
modelId=self.config.embedder_model_name,
142142
accept="application/json",
143143
contentType="application/json",
144144
)
@@ -173,15 +173,15 @@ def wrap_error(self, e: Exception) -> Exception:
173173

174174
async def embed_query(self, query: str) -> list[float]:
175175
"""Call out to Bedrock embedding endpoint."""
176-
provider = self.config.embed_model_name.split(".")[0]
176+
provider = self.config.embedder_model_name.split(".")[0]
177177
body = conform_query(query=query, provider=provider)
178178
try:
179179
async with self.config.get_async_client() as bedrock_client:
180180
# invoke bedrock API
181181
try:
182182
response = await bedrock_client.invoke_model(
183183
body=json.dumps(body),
184-
modelId=self.config.embed_model_name,
184+
modelId=self.config.embedder_model_name,
185185
accept="application/json",
186186
contentType="application/json",
187187
)

unstructured_ingest/v2/processes/embedder.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -92,18 +92,20 @@ def get_octoai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
9292

9393
return OctoAIEmbeddingEncoder(config=OctoAiEmbeddingConfig.model_validate(embedding_kwargs))
9494

95-
def get_bedrock_embedder(self) -> "BaseEmbeddingEncoder":
95+
def get_bedrock_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
9696
from unstructured_ingest.embed.bedrock import (
9797
BedrockEmbeddingConfig,
9898
BedrockEmbeddingEncoder,
9999
)
100100

101+
embedding_kwargs = embedding_kwargs | {
102+
"aws_access_key_id": self.embedding_aws_access_key_id,
103+
"aws_secret_access_key": self.embedding_aws_secret_access_key.get_secret_value(),
104+
"region_name": self.embedding_aws_region,
105+
}
106+
101107
return BedrockEmbeddingEncoder(
102-
config=BedrockEmbeddingConfig(
103-
aws_access_key_id=self.embedding_aws_access_key_id,
104-
aws_secret_access_key=self.embedding_aws_secret_access_key.get_secret_value(),
105-
region_name=self.embedding_aws_region,
106-
)
108+
config=BedrockEmbeddingConfig.model_validate(embedding_kwargs)
107109
)
108110

109111
def get_vertexai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
@@ -163,7 +165,7 @@ def get_embedder(self) -> "BaseEmbeddingEncoder":
163165
return self.get_octoai_embedder(embedding_kwargs=kwargs)
164166

165167
if self.embedding_provider == "bedrock":
166-
return self.get_bedrock_embedder()
168+
return self.get_bedrock_embedder(embedding_kwargs=kwargs)
167169

168170
if self.embedding_provider == "vertexai":
169171
return self.get_vertexai_embedder(embedding_kwargs=kwargs)

0 commit comments

Comments
 (0)