Skip to content

Commit 5288def

Browse files
authored
FixBug: Inference profile isn't working for rerank model (#2141)
1 parent c13cd14 commit 5288def

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

models/bedrock/manifest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version: 0.0.51
1+
version: 0.0.52
22
type: plugin
33
author: langgenius
44
name: bedrock

models/bedrock/models/rerank/rerank.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22
import logging
3+
import json
34

45
from botocore.exceptions import ClientError
56

@@ -61,8 +62,7 @@ def _invoke(
6162
return RerankResult(model=model, docs=docs)
6263

6364
# initialize client
64-
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
65-
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
65+
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
6666
text_sources = []
6767
for text in docs:
6868
text_sources.append(
@@ -76,6 +76,7 @@ def _invoke(
7676
},
7777
}
7878
)
79+
7980
# Check if using inference profile
8081
model_id = model
8182
inference_profile_id = credentials.get("inference_profile_id")
@@ -93,27 +94,37 @@ def _invoke(
9394
if not region:
9495
raise InvokeBadRequestError("aws_region is required in credentials")
9596
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{model_id}"
96-
rerankingConfiguration = {
97-
"type": "BEDROCK_RERANKING_MODEL",
98-
"bedrockRerankingConfiguration": {
99-
"numberOfResults": min(len(text_sources) if top_n is None else top_n, len(text_sources)),
100-
"modelConfiguration": {
101-
"modelArn": model_package_arn,
102-
},
103-
},
97+
98+
numberOfResults = min(len(text_sources) if top_n is None else top_n, len(text_sources))
99+
100+
body_dict = {
101+
"query": query,
102+
"documents": docs
104103
}
105-
response = bedrock_runtime.rerank(
106-
queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
104+
105+
# Only add api_version for Cohere models
106+
if "cohere" in model_id.lower():
107+
body_dict["api_version"] = 2
108+
109+
body = json.dumps(body_dict)
110+
111+
response = bedrock_runtime.invoke_model(
112+
modelId=model_package_arn,
113+
body=body
107114
)
108115

116+
body_content = json.loads(response['body'].read())
117+
118+
results_to_process = body_content['results'][:numberOfResults]
119+
109120
rerank_documents = []
110-
for idx, result in enumerate(response["results"]):
121+
for idx, result in enumerate(results_to_process):
111122
# format document
112123
index = result["index"]
113124
rerank_document = RerankDocument(
114125
index=index,
115126
text=docs[index],
116-
score=result["relevanceScore"],
127+
score=result["relevance_score"],
117128
)
118129

119130
# score threshold check

0 commit comments

Comments
 (0)