@@ -24,6 +24,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
24
24
@dataclasses .dataclass
25
25
class _Candidate :
26
26
id : str
27
+ similarity : float
27
28
weighted_similarity : float
28
29
weighted_redundancy : float
29
30
score : float = dataclasses .field (init = False )
@@ -69,6 +70,13 @@ class MmrHelper:
69
70
70
71
selected_ids : list [str ]
71
72
"""List of selected IDs (in selection order)."""
73
+
74
+ selected_mmr_scores : list [float ]
75
+ """List of MMR scores, each recorded at the time its document was selected."""
76
+
77
+ selected_similarity_scores : list [float ]
78
+ """List of similarity scores, one for each selected document."""
79
+
72
80
selected_embeddings : NDArray [np .float32 ]
73
81
"""(N, dim) ndarray with a row for each selected node."""
74
82
@@ -100,6 +108,8 @@ def __init__(
100
108
self .score_threshold = score_threshold
101
109
102
110
self .selected_ids = []
111
+ self .selected_similarity_scores = []
112
+ self .selected_mmr_scores = []
103
113
104
114
# List of selected embeddings (in selection order).
105
115
self .selected_embeddings = np .ndarray ((k , self .dimensions ), dtype = np .float32 )
@@ -123,11 +133,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
123
133
selected = len (self .selected_ids )
124
134
return np .vsplit (self .selected_embeddings , [selected ])[0 ]
125
135
126
- def _pop_candidate (self , candidate_id : str ) -> NDArray [np .float32 ]:
136
+ def _pop_candidate (self , candidate_id : str ) -> tuple [ float , NDArray [np .float32 ] ]:
127
137
"""Pop the candidate with the given ID.
128
138
129
139
Returns:
130
- The embedding of the candidate.
140
+ The similarity score and embedding of the candidate.
131
141
"""
132
142
# Get the embedding for the id.
133
143
index = self .candidate_id_to_index .pop (candidate_id )
@@ -143,12 +153,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
143
153
# candidate_embeddings.
144
154
last_index = self .candidate_embeddings .shape [0 ] - 1
145
155
156
+ similarity = 0.0
146
157
if index == last_index :
147
158
# Already the last item. We don't need to swap.
148
- self .candidates .pop ()
159
+ similarity = self .candidates .pop (). similarity
149
160
else :
150
161
self .candidate_embeddings [index ] = self .candidate_embeddings [last_index ]
151
162
163
+ similarity = self .candidates [index ].similarity
164
+
152
165
old_last = self .candidates .pop ()
153
166
self .candidates [index ] = old_last
154
167
self .candidate_id_to_index [old_last .id ] = index
@@ -157,7 +170,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
157
170
0
158
171
]
159
172
160
- return embedding
173
+ return similarity , embedding
161
174
162
175
def pop_best (self ) -> str | None :
163
176
"""Select and pop the best item being considered.
@@ -172,11 +185,13 @@ def pop_best(self) -> str | None:
172
185
173
186
# Get the selection and remove from candidates.
174
187
selected_id = self .best_id
175
- selected_embedding = self ._pop_candidate (selected_id )
188
+ selected_similarity , selected_embedding = self ._pop_candidate (selected_id )
176
189
177
190
# Add the ID and embedding to the selected information.
178
191
selection_index = len (self .selected_ids )
179
192
self .selected_ids .append (selected_id )
193
+ self .selected_mmr_scores .append (self .best_score )
194
+ self .selected_similarity_scores .append (selected_similarity )
180
195
self .selected_embeddings [selection_index ] = selected_embedding
181
196
182
197
# Reset the best score / best ID.
@@ -232,6 +247,7 @@ def add_candidates(self, candidates: dict[str, list[float]]) -> None:
232
247
max_redundancy = redundancy [index ].max ()
233
248
candidate = _Candidate (
234
249
id = candidate_id ,
250
+ similarity = similarity [index ][0 ],
235
251
weighted_similarity = self .lambda_mult * similarity [index ][0 ],
236
252
weighted_redundancy = self .lambda_mult_complement * max_redundancy ,
237
253
)
0 commit comments