@@ -50,6 +50,19 @@ class AmazonBedrockRerankerConf(YamlSerializableMixin):
5050 )
5151
5252
53+ class CohereRerankerConf (YamlSerializableMixin ):
54+ """Parameters for CohereReranker."""
55+
56+ cohere_key : SecretStr | None = Field (
57+ ...,
58+ description = "Cohere API key for authentication." ,
59+ )
60+ model : str = Field (
61+ default = "rerank-english-v3.0" ,
62+ description = "Cohere rerank model" ,
63+ )
64+
65+
5366class CrossEncoderRerankerConf (YamlSerializableMixin ):
5467 """Parameters for CrossEncoderReranker."""
5568
@@ -89,6 +102,7 @@ class RerankersConf(BaseModel):
89102
90103 bm25 : dict [str , BM25RerankerConf ] = {}
91104 amazon_bedrock : dict [str , AmazonBedrockRerankerConf ] = {}
105+ cohere : dict [str , CohereRerankerConf ] = {}
92106 cross_encoder : dict [str , CrossEncoderRerankerConf ] = {}
93107 embedder : dict [str , EmbedderRerankerConf ] = {}
94108 identity : dict [str , IdentityRerankerConf ] = {}
@@ -101,6 +115,7 @@ def contains_reranker(self, reranker_id: str) -> bool:
101115 return reranker_id in self ._saved_reranker_ids
102116
103117 BM25 : ClassVar [str ] = "bm25"
118+ COHERE : ClassVar [str ] = "cohere"
104119 CROSS_ENCODER : ClassVar [str ] = "cross-encoder"
105120 EMBEDDER : ClassVar [str ] = "embedder"
106121 IDENTITY : ClassVar [str ] = "identity"
@@ -126,6 +141,9 @@ def add_reranker(name: str, provider: str, config: dict) -> None:
126141 for reranker_id , conf in self .amazon_bedrock .items ():
127142 add_reranker (reranker_id , self .AMAZON_BEDROCK , conf .to_yaml_dict ())
128143
144+ for reranker_id , conf in self .cohere .items ():
145+ add_reranker (reranker_id , self .COHERE , conf .to_yaml_dict ())
146+
129147 for reranker_id , conf in self .cross_encoder .items ():
130148 add_reranker (reranker_id , self .CROSS_ENCODER , conf .to_yaml_dict ())
131149
@@ -154,6 +172,7 @@ def parse(cls, input_dict: dict) -> Self:
154172
155173 bm25_dict = {}
156174 amazon_bedrock_dict = {}
175+ cohere_dict = {}
157176 cross_encoder_dict = {}
158177 embedder_dict = {}
159178 identity_dict = {}
@@ -167,6 +186,8 @@ def parse(cls, input_dict: dict) -> Self:
167186 bm25_dict [reranker_id ] = BM25RerankerConf (** conf )
168187 elif provider == cls .AMAZON_BEDROCK :
169188 amazon_bedrock_dict [reranker_id ] = AmazonBedrockRerankerConf (** conf )
189+ elif provider == cls .COHERE :
190+ cohere_dict [reranker_id ] = CohereRerankerConf (** conf )
170191 elif provider == cls .CROSS_ENCODER :
171192 cross_encoder_dict [reranker_id ] = CrossEncoderRerankerConf (** conf )
172193 elif provider == cls .EMBEDDER :
@@ -183,6 +204,7 @@ def parse(cls, input_dict: dict) -> Self:
183204 ret = cls (
184205 bm25 = bm25_dict ,
185206 amazon_bedrock = amazon_bedrock_dict ,
207+ cohere = cohere_dict ,
186208 cross_encoder = cross_encoder_dict ,
187209 embedder = embedder_dict ,
188210 identity = identity_dict ,
0 commit comments