11from typing import Optional
22import logging
3+ import json
34
45from botocore .exceptions import ClientError
56
@@ -61,8 +62,7 @@ def _invoke(
6162 return RerankResult (model = model , docs = docs )
6263
6364 # initialize client
64- bedrock_runtime = get_bedrock_client ("bedrock-agent-runtime" , credentials )
65- queries = [{"type" : "TEXT" , "textQuery" : {"text" : query }}]
65+ bedrock_runtime = get_bedrock_client ("bedrock-runtime" , credentials )
6666 text_sources = []
6767 for text in docs :
6868 text_sources .append (
@@ -76,6 +76,7 @@ def _invoke(
7676 },
7777 }
7878 )
79+
7980 # Check if using inference profile
8081 model_id = model
8182 inference_profile_id = credentials .get ("inference_profile_id" )
@@ -93,27 +94,37 @@ def _invoke(
9394 if not region :
9495 raise InvokeBadRequestError ("aws_region is required in credentials" )
9596 model_package_arn = f"arn:aws:bedrock:{ region } ::foundation-model/{ model_id } "
96- rerankingConfiguration = {
97- "type" : "BEDROCK_RERANKING_MODEL" ,
98- "bedrockRerankingConfiguration" : {
99- "numberOfResults" : min (len (text_sources ) if top_n is None else top_n , len (text_sources )),
100- "modelConfiguration" : {
101- "modelArn" : model_package_arn ,
102- },
103- },
97+
98+ numberOfResults = min (len (text_sources ) if top_n is None else top_n , len (text_sources ))
99+
100+ body_dict = {
101+ "query" : query ,
102+ "documents" : docs
104103 }
105- response = bedrock_runtime .rerank (
106- queries = queries , sources = text_sources , rerankingConfiguration = rerankingConfiguration
104+
105+ # Only add api_version for Cohere models
106+ if "cohere" in model_id .lower ():
107+ body_dict ["api_version" ] = 2
108+
109+ body = json .dumps (body_dict )
110+
111+ response = bedrock_runtime .invoke_model (
112+ modelId = model_package_arn ,
113+ body = body
107114 )
108115
116+ body_content = json .loads (response ['body' ].read ())
117+
118+ results_to_process = body_content ['results' ][:numberOfResults ]
119+
109120 rerank_documents = []
110- for idx , result in enumerate (response [ "results" ] ):
121+ for idx , result in enumerate (results_to_process ):
111122 # format document
112123 index = result ["index" ]
113124 rerank_document = RerankDocument (
114125 index = index ,
115126 text = docs [index ],
116- score = result ["relevanceScore " ],
127+ score = result ["relevance_score " ],
117128 )
118129
119130 # score threshold check
0 commit comments