Skip to content

Commit 048fefc

Browse files
authored
Merge branch 'main' into fix-macos-installation
2 parents 61ccb0f + b182977 commit 048fefc

File tree

10 files changed

+208
-10
lines changed

10 files changed

+208
-10
lines changed

.github/workflows/docker-image.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,15 @@ jobs:
9999
context: .
100100
file: ./Dockerfile
101101
push: true
102+
# FIXED: Explicitly disable GPU here to ensure 'latest' is the CPU version
103+
build-args: |
104+
GPU=false
102105
tags: ${{ steps.meta-cpu.outputs.tags }}
103106
labels: ${{ steps.meta-cpu.outputs.labels }}
104107
platforms: linux/amd64,linux/arm64
105-
cache-from: type=gha
106-
cache-to: type=gha,mode=max
108+
# FIXED: Added scope 'build-cpu' to isolate this cache
109+
cache-from: type=gha,scope=build-cpu
110+
cache-to: type=gha,mode=max,scope=build-cpu
107111

108112
# Step 7: Show the digest of the built CPU image (useful for traceability)
109113
- name: Show CPU image digest
@@ -185,8 +189,9 @@ jobs:
185189
tags: ${{ steps.meta-gpu.outputs.tags }}
186190
labels: ${{ steps.meta-gpu.outputs.labels }}
187191
platforms: linux/amd64,linux/arm64
188-
cache-from: type=gha
189-
cache-to: type=gha,mode=max
192+
# FIXED: Added scope 'build-gpu' to isolate this cache
193+
cache-from: type=gha,scope=build-gpu
194+
cache-to: type=gha,mode=max,scope=build-gpu
190195

191196
# Step 7: Show the digest of the built GPU image (useful for traceability)
192197
- name: Show GPU image digest

build-docker.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ get_interactive_choices() {
177177
if [[ -z "$VERSION" ]]; then error "Version tag cannot be empty."; fi
178178
fi
179179

180-
msg "A 'latest' tag (e.g., 'latest-gpu') acts as a pointer to the newest stable release."
181-
read -p "Apply 'latest' tags for this version? (Y/n): " push_latest_choice
180+
msg "A 'latest' tag (e.g., 'latest-cpu') acts as a pointer to the newest stable release."
181+
read -p "Apply 'latest' tag for this version? (Y/n): " push_latest_choice
182182
if [[ ! "$push_latest_choice" =~ ^[nN]$ ]]; then PUSH_LATEST=true; fi
183183

184184
read -p "Select action (1: Build locally, 2: Build and Push, 3: Cancel) [2]: " action_choice

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"aiosqlite>=0.20.0",
1919
"boto3>=1.40.40",
2020
"boto3-stubs==1.40.64",
21+
"cohere>=5.20.0",
2122
"coverage>=7.11.0",
2223
"dotenv>=0.9.9",
2324
"fastapi>=0.116.1",

sample_configs/episodic_memory_config.cpu.sample

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ resources:
100100
provider: "identity"
101101
bm_ranker_id:
102102
provider: "bm25"
103+
cohere_reranker_id:
104+
provider: "cohere"
105+
config:
106+
cohere_key: <COHERE_API_KEY>
107+
model: "rerank-english-v3.0"
103108
aws_reranker_id:
104109
provider: "amazon-bedrock"
105110
config:

sample_configs/episodic_memory_config.gpu.sample

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ resources:
101101
provider: "identity"
102102
bm_ranker_id:
103103
provider: "bm25"
104+
cohere_reranker_id:
105+
provider: "cohere"
106+
config:
107+
cohere_key: <COHERE_API_KEY>
108+
model: "rerank-english-v3.0"
104109
ce_ranker_id:
105110
provider: "cross-encoder"
106111
config:

src/memmachine/common/configuration/reranker_conf.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ class AmazonBedrockRerankerConf(YamlSerializableMixin):
5050
)
5151

5252

53+
class CohereRerankerConf(YamlSerializableMixin):
54+
"""Parameters for CohereReranker."""
55+
56+
cohere_key: SecretStr | None = Field(
57+
...,
58+
description="Cohere API key for authentication.",
59+
)
60+
model: str = Field(
61+
default="rerank-english-v3.0",
62+
description="Cohere rerank model",
63+
)
64+
65+
5366
class CrossEncoderRerankerConf(YamlSerializableMixin):
5467
"""Parameters for CrossEncoderReranker."""
5568

@@ -89,6 +102,7 @@ class RerankersConf(BaseModel):
89102

90103
bm25: dict[str, BM25RerankerConf] = {}
91104
amazon_bedrock: dict[str, AmazonBedrockRerankerConf] = {}
105+
cohere: dict[str, CohereRerankerConf] = {}
92106
cross_encoder: dict[str, CrossEncoderRerankerConf] = {}
93107
embedder: dict[str, EmbedderRerankerConf] = {}
94108
identity: dict[str, IdentityRerankerConf] = {}
@@ -101,6 +115,7 @@ def contains_reranker(self, reranker_id: str) -> bool:
101115
return reranker_id in self._saved_reranker_ids
102116

103117
BM25: ClassVar[str] = "bm25"
118+
COHERE: ClassVar[str] = "cohere"
104119
CROSS_ENCODER: ClassVar[str] = "cross-encoder"
105120
EMBEDDER: ClassVar[str] = "embedder"
106121
IDENTITY: ClassVar[str] = "identity"
@@ -126,6 +141,9 @@ def add_reranker(name: str, provider: str, config: dict) -> None:
126141
for reranker_id, conf in self.amazon_bedrock.items():
127142
add_reranker(reranker_id, self.AMAZON_BEDROCK, conf.to_yaml_dict())
128143

144+
for reranker_id, conf in self.cohere.items():
145+
add_reranker(reranker_id, self.COHERE, conf.to_yaml_dict())
146+
129147
for reranker_id, conf in self.cross_encoder.items():
130148
add_reranker(reranker_id, self.CROSS_ENCODER, conf.to_yaml_dict())
131149

@@ -154,6 +172,7 @@ def parse(cls, input_dict: dict) -> Self:
154172

155173
bm25_dict = {}
156174
amazon_bedrock_dict = {}
175+
cohere_dict = {}
157176
cross_encoder_dict = {}
158177
embedder_dict = {}
159178
identity_dict = {}
@@ -167,6 +186,8 @@ def parse(cls, input_dict: dict) -> Self:
167186
bm25_dict[reranker_id] = BM25RerankerConf(**conf)
168187
elif provider == cls.AMAZON_BEDROCK:
169188
amazon_bedrock_dict[reranker_id] = AmazonBedrockRerankerConf(**conf)
189+
elif provider == cls.COHERE:
190+
cohere_dict[reranker_id] = CohereRerankerConf(**conf)
170191
elif provider == cls.CROSS_ENCODER:
171192
cross_encoder_dict[reranker_id] = CrossEncoderRerankerConf(**conf)
172193
elif provider == cls.EMBEDDER:
@@ -183,6 +204,7 @@ def parse(cls, input_dict: dict) -> Self:
183204
ret = cls(
184205
bm25=bm25_dict,
185206
amazon_bedrock=amazon_bedrock_dict,
207+
cohere=cohere_dict,
186208
cross_encoder=cross_encoder_dict,
187209
embedder=embedder_dict,
188210
identity=identity_dict,
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Cohere reranker implementation."""
2+
3+
import asyncio
4+
import logging
5+
from typing import Any
6+
7+
import cohere
8+
from pydantic import BaseModel, Field
9+
10+
from memmachine.common.data_types import ExternalServiceAPIError
11+
12+
from .reranker import Reranker
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class CohereRerankerParams(BaseModel):
18+
"""Configuration parameters for CohereReranker."""
19+
20+
client: Any = Field(
21+
...,
22+
description="Cohere client instance for making API calls",
23+
)
24+
model: str = Field(
25+
"rerank-english-v3.0",
26+
description="Cohere rerank model",
27+
)
28+
29+
30+
class CohereReranker(Reranker):
31+
"""Reranker using Cohere's rerank API."""
32+
33+
def __init__(self, params: CohereRerankerParams) -> None:
34+
"""Initialize a CohereReranker with the provided parameters."""
35+
super().__init__()
36+
37+
self._client = params.client
38+
self._model = params.model
39+
40+
async def score(self, query: str, candidates: list[str]) -> list[float]:
41+
"""Score candidates using Cohere's rerank API."""
42+
43+
# Build request parameters
44+
def _call_rerank() -> cohere.RerankResponse:
45+
return self._client.rerank(
46+
model=self._model,
47+
query=query,
48+
documents=candidates,
49+
)
50+
51+
try:
52+
response = await asyncio.to_thread(_call_rerank)
53+
except Exception as e:
54+
error_message = (
55+
f"Failed to score candidates with Cohere model {self._model} "
56+
f"due to {type(e).__name__}: {e}"
57+
)
58+
logger.exception(error_message)
59+
raise ExternalServiceAPIError(error_message) from e
60+
61+
# Cohere returns ranked order — map scores back to original positions
62+
scores = [0.0] * len(candidates)
63+
for result in response.results:
64+
scores[result.index] = result.relevance_score
65+
66+
return scores

src/memmachine/common/resource_manager/reranker_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def build_all(self) -> dict[str, Reranker]:
4747
name
4848
for keys in [
4949
self.conf.bm25.keys(),
50+
self.conf.cohere.keys(),
5051
self.conf.cross_encoder.keys(),
5152
self.conf.amazon_bedrock.keys(),
5253
self.conf.embedder.keys(),
@@ -80,6 +81,8 @@ async def _build_reranker(self, name: str) -> Reranker:
8081
"""Create a reranker based on provider-specific configuration."""
8182
if name in self.conf.bm25:
8283
return await self._build_bm25_reranker(name)
84+
if name in self.conf.cohere:
85+
return await self._build_cohere_reranker(name)
8386
if name in self.conf.cross_encoder:
8487
return await self._build_cross_encoder_reranker(name)
8588
if name in self.conf.amazon_bedrock:
@@ -129,6 +132,25 @@ def _default_tokenize(text: str) -> list[str]:
129132
)
130133
return self.rerankers[name]
131134

135+
async def _build_cohere_reranker(self, name: str) -> Reranker:
136+
from cohere import ClientV2
137+
138+
from memmachine.common.reranker.cohere_reranker import (
139+
CohereReranker,
140+
CohereRerankerParams,
141+
)
142+
143+
conf = self.conf.cohere[name]
144+
145+
cohere_api_key = conf.cohere_key.get_secret_value() if conf.cohere_key else None
146+
client = ClientV2(api_key=cohere_api_key)
147+
params = CohereRerankerParams(
148+
client=client,
149+
model=conf.model,
150+
)
151+
self.rerankers[name] = CohereReranker(params)
152+
return self.rerankers[name]
153+
132154
async def _build_cross_encoder_reranker(self, name: str) -> Reranker:
133155
from sentence_transformers import CrossEncoder
134156

tests/memmachine/common/resource_manager/test_reranker_manager.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from memmachine.common.configuration.reranker_conf import (
77
AmazonBedrockRerankerConf,
88
BM25RerankerConf,
9+
CohereRerankerConf,
910
CrossEncoderRerankerConf,
1011
RerankersConf,
1112
RRFHybridRerankerConf,
@@ -23,7 +24,13 @@ def mock_conf():
2324
),
2425
},
2526
identity={"id_ranker_id": {}},
26-
bm25={"bm_ranker_id": BM25RerankerConf(tokenize="simple")},
27+
bm25={"bm_ranker_id": BM25RerankerConf(tokenizer="simple")},
28+
cohere={
29+
"cohere_reranker_id": CohereRerankerConf(
30+
cohere_key=SecretStr("<COHERE_API_KEY>"),
31+
model="rerank-english-v3.0",
32+
),
33+
},
2734
cross_encoder={
2835
"ce_ranker_id": CrossEncoderRerankerConf(
2936
model_name="cross-encoder/qnli-electra-base",
@@ -79,6 +86,15 @@ async def test_reranker_not_found(reranker_manager):
7986
await reranker_manager.get_reranker("unknown_reranker_id")
8087

8188

89+
@pytest.mark.asyncio
90+
async def test_build_cohere_rerankers(reranker_manager):
91+
await reranker_manager.build_all()
92+
93+
assert "cohere_reranker_id" in reranker_manager.rerankers
94+
reranker = reranker_manager.rerankers["cohere_reranker_id"]
95+
assert reranker is not None
96+
97+
8298
@pytest.mark.asyncio
8399
async def test_build_cross_encoder_rerankers(reranker_manager):
84100
await reranker_manager.build_all()

0 commit comments

Comments
 (0)