Skip to content

Commit a8a092a

Browse files
authored
feat: add scores to MMR results (#652)
* feat: add scores to MMR results * add tests * adjust tolerance
1 parent bca3e4f commit a8a092a

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
2424
@dataclasses.dataclass
2525
class _Candidate:
2626
id: str
27+
similarity: float
2728
weighted_similarity: float
2829
weighted_redundancy: float
2930
score: float = dataclasses.field(init=False)
@@ -69,6 +70,13 @@ class MmrHelper:
6970

7071
selected_ids: list[str]
7172
"""List of selected IDs (in selection order)."""
73+
74+
selected_mmr_scores: list[float]
75+
"""List of MMR score at the time each document is selected."""
76+
77+
selected_similarity_scores: list[float]
78+
"""List of similarity score for each selected document."""
79+
7280
selected_embeddings: NDArray[np.float32]
7381
"""(N, dim) ndarray with a row for each selected node."""
7482

@@ -100,6 +108,8 @@ def __init__(
100108
self.score_threshold = score_threshold
101109

102110
self.selected_ids = []
111+
self.selected_similarity_scores = []
112+
self.selected_mmr_scores = []
103113

104114
# List of selected embeddings (in selection order).
105115
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)
@@ -123,11 +133,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
123133
selected = len(self.selected_ids)
124134
return np.vsplit(self.selected_embeddings, [selected])[0]
125135

126-
def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
136+
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
127137
"""Pop the candidate with the given ID.
128138
129139
Returns:
130-
The embedding of the candidate.
140+
The similarity score and embedding of the candidate.
131141
"""
132142
# Get the embedding for the id.
133143
index = self.candidate_id_to_index.pop(candidate_id)
@@ -143,12 +153,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
143153
# candidate_embeddings.
144154
last_index = self.candidate_embeddings.shape[0] - 1
145155

156+
similarity = 0.0
146157
if index == last_index:
147158
# Already the last item. We don't need to swap.
148-
self.candidates.pop()
159+
similarity = self.candidates.pop().similarity
149160
else:
150161
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]
151162

163+
similarity = self.candidates[index].similarity
164+
152165
old_last = self.candidates.pop()
153166
self.candidates[index] = old_last
154167
self.candidate_id_to_index[old_last.id] = index
@@ -157,7 +170,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
157170
0
158171
]
159172

160-
return embedding
173+
return similarity, embedding
161174

162175
def pop_best(self) -> str | None:
163176
"""Select and pop the best item being considered.
@@ -172,11 +185,13 @@ def pop_best(self) -> str | None:
172185

173186
# Get the selection and remove from candidates.
174187
selected_id = self.best_id
175-
selected_embedding = self._pop_candidate(selected_id)
188+
selected_similarity, selected_embedding = self._pop_candidate(selected_id)
176189

177190
# Add the ID and embedding to the selected information.
178191
selection_index = len(self.selected_ids)
179192
self.selected_ids.append(selected_id)
193+
self.selected_mmr_scores.append(self.best_score)
194+
self.selected_similarity_scores.append(selected_similarity)
180195
self.selected_embeddings[selection_index] = selected_embedding
181196

182197
# Reset the best score / best ID.
@@ -232,6 +247,7 @@ def add_candidates(self, candidates: dict[str, list[float]]) -> None:
232247
max_redundancy = redundancy[index].max()
233248
candidate = _Candidate(
234249
id=candidate_id,
250+
similarity=similarity[index][0],
235251
weighted_similarity=self.lambda_mult * similarity[index][0],
236252
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
237253
)

libs/knowledge-store/ragstack_knowledge_store/graph_store.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,13 @@ def fetch_initial_candidates() -> None:
549549
depths[adjacent.target_content_id] = next_depth
550550
helper.add_candidates(new_candidates)
551551

552-
return self._nodes_with_ids(helper.selected_ids)
552+
nodes = self._nodes_with_ids(helper.selected_ids)
553+
for node, similarity_score, mmr_score in zip(
554+
nodes, helper.selected_similarity_scores, helper.selected_mmr_scores
555+
):
556+
node.metadata["similarity_score"] = similarity_score
557+
node.metadata["mmr_score"] = mmr_score
558+
return nodes
553559

554560
def traversal_search(
555561
self,

libs/knowledge-store/tests/unit_tests/test_mmr_helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,8 @@ def test_mmr_helper_added_documetns() -> None:
7373
}
7474
)
7575
assert helper.pop_best() == "v2"
76+
77+
assert math.isclose(helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001)
78+
assert math.isclose(helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001)
79+
assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001)
80+
assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001)

0 commit comments

Comments
 (0)