Skip to content

Commit 4e4f751

Browse files
authored
Modified RetrievalEvaluator.py to evaluate using energy distance (ED)
- Changed output_value for encode_queries to token_embeddings and returned the attention_mask (2D tensor) for the batch of queries; query_embeddings is now a 3D tensor with padded embeddings, since the encode() function in sentence-transformers has been modified.
- Added a precompute_corpus_embeddings function that computes and returns the corpus embeddings before iterating through query batches.
- Changed the loop so that we iterate through batches of query embeddings along with the chunks of precomputed corpus embeddings.
- Moved each corpus embedding chunk to the GPU inside the loop and calculated the starting index of the chunk.
- Modified the score function to be energy distance, with the arguments to the ED function being query_embeddings (3D tensor), corpus_embeddings (2D tensor), and attention_masks (2D tensor).
- Replaced NaN similarity scores with negative infinity (in our case a higher ED means greater similarity, because we flipped the sign of the ED calculation).
- When using the min-heap to store the top-k documents for each query, used the global query index and global corpus index to identify a query and a corpus document respectively (since the loop iterates through batches of queries and chunks of the corpus).
- After calculating scores for a query batch / corpus chunk combination and updating the min-heap, moved the corpus embeddings chunk back to the CPU (because it will be reused) and deleted the query embeddings batch.
- Modified the encode() calls in encode_conversations() and encode() to add output_value="token_embeddings" as an argument.
1 parent 07b73f7 commit 4e4f751

File tree

1 file changed

+109
-71
lines changed

1 file changed

+109
-71
lines changed

mteb/evaluation/evaluators/RetrievalEvaluator.py

Lines changed: 109 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@
2222
confidence_scores,
2323
convert_conv_history_to_query,
2424
cos_sim,
25+
energy_calc,
26+
energy_distance,
2527
download,
2628
hole,
2729
mrr,
@@ -95,6 +97,19 @@ def __init__(
9597
# custom functions can be used by extending the DenseRetrievalExactSearch class
9698
self.predict = self.model.predict
9799

100+
def precompute_corpus_embeddings(self, corpus, model, batch_size, chunk_size):
101+
all_corpus_embeddings = []
102+
print("Length of corpus:", len(corpus))
103+
print("Batch size:", batch_size)
104+
for start_idx in range(0, len(corpus), chunk_size):
105+
end_idx = min(start_idx + chunk_size, len(corpus))
106+
print("Chunk to document:", end_idx)
107+
chunk = corpus[start_idx:end_idx]
108+
embeddings = model.encode_corpus(chunk, batch_size=batch_size, convert_to_tensor=True)
109+
embeddings = embeddings.to('cpu')
110+
all_corpus_embeddings.append(embeddings)
111+
return all_corpus_embeddings
112+
98113
def search(
99114
self,
100115
corpus: dict[str, dict[str, str]],
@@ -120,8 +135,9 @@ def search(
120135
**self.encode_kwargs,
121136
)
122137
else:
123-
query_embeddings = self.model.encode(
138+
query_embeddings, attention_masks = self.model.encode(
124139
queries, # type: ignore
140+
output_value="token_embeddings",
125141
task_name=task_name,
126142
prompt_type=PromptType.query,
127143
**self.encode_kwargs,
@@ -135,83 +151,105 @@ def search(
135151
corpus = [corpus[cid] for cid in corpus_ids] # type: ignore
136152

137153
logger.info("Encoding Corpus in batches... Warning: This might take a while!")
154+
# Precompute all corpus embeddings
155+
all_corpus_embeddings = self.precompute_corpus_embeddings(
156+
corpus=corpus,
157+
model=self.model, # Use the corpus-specific model
158+
batch_size=self.batch_size,
159+
chunk_size=self.corpus_chunk_size
160+
)
138161

139162
itr = range(0, len(corpus), self.corpus_chunk_size)
140163

141164
result_heaps = {
142165
qid: [] for qid in query_ids
143166
} # Keep only the top-k docs for each query
144-
for batch_num, corpus_start_idx in enumerate(itr):
145-
logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
146-
corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))
147-
148-
# Encode chunk of corpus
149-
if (
150-
self.save_corpus_embeddings
151-
and request_qid
152-
and len(self.corpus_embeddings[request_qid])
153-
):
154-
sub_corpus_embeddings = torch.tensor(
155-
self.corpus_embeddings[request_qid][batch_num]
156-
)
157-
else:
158-
# Encode chunk of corpus
159-
sub_corpus_embeddings = self.model.encode(
160-
corpus[corpus_start_idx:corpus_end_idx], # type: ignore
161-
task_name=task_name,
162-
prompt_type=PromptType.passage,
163-
request_qid=request_qid,
164-
**self.encode_kwargs,
165-
)
166-
if self.save_corpus_embeddings and request_qid:
167-
self.corpus_embeddings[request_qid].append(sub_corpus_embeddings)
168167

169-
# Compute similarites using self defined similarity otherwise default to cosine-similarity
170-
if hasattr(self.model, "similarity"):
171-
similarity_scores = self.model.similarity(
172-
query_embeddings, sub_corpus_embeddings
168+
for query_batch_index, query_batch in enumerate(query_embeddings):
169+
for chunk_idx, sub_corpus_embeddings in enumerate(all_corpus_embeddings):
170+
#for batch_num, corpus_start_idx in enumerate(itr):
171+
logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
172+
#corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))
173+
sub_corpus_embeddings = sub_corpus_embeddings.to('cuda')
174+
chunk_start_idx = chunk_idx * self.corpus_chunk_size # Calculate the starting index of this chunk
175+
176+
177+
# Encode chunk of corpus
178+
#if (
179+
# self.save_corpus_embeddings
180+
# and request_qid
181+
# and len(self.corpus_embeddings[request_qid])
182+
#):
183+
# sub_corpus_embeddings = torch.tensor(
184+
# self.corpus_embeddings[request_qid][batch_num]
185+
# )
186+
#else:
187+
# Encode chunk of corpus
188+
# sub_corpus_embeddings = self.model.encode(
189+
# corpus[corpus_start_idx:corpus_end_idx], # type: ignore
190+
# task_name=task_name,
191+
# prompt_type=PromptType.passage,
192+
# request_qid=request_qid,
193+
# **self.encode_kwargs,
194+
# )
195+
# if self.save_corpus_embeddings and request_qid:
196+
# self.corpus_embeddings[request_qid].append(sub_corpus_embeddings)
197+
198+
# Compute similarites using self defined similarity otherwise default to cosine-similarity
199+
#if hasattr(self.model, "similarity"):
200+
# similarity_scores = self.model.similarity(
201+
# query_embeddings, sub_corpus_embeddings
202+
# )
203+
#else:
204+
similarity_scores = energy_distance(query_embeddings, sub_corpus_embeddings, attention_masks[query_batch_index])
205+
is_nan = torch.isnan(similarity_scores)
206+
if is_nan.sum() > 0:
207+
logger.warning(
208+
f"Found {is_nan.sum()} NaN values in the similarity scores. Replacing NaN values with -inf."
209+
)
210+
similarity_scores[is_nan] = float('inf') * -1
211+
212+
# Get top-k values
213+
similarity_scores_top_k_values, similarity_scores_top_k_idx = torch.topk(
214+
similarity_scores,
215+
min(
216+
top_k + 1,
217+
len(similarity_scores[1])
218+
if len(similarity_scores) > 1
219+
else len(similarity_scores[-1]),
220+
),
221+
dim=1,
222+
largest=True,
223+
sorted=return_sorted,
173224
)
174-
else:
175-
similarity_scores = cos_sim(query_embeddings, sub_corpus_embeddings)
176-
is_nan = torch.isnan(similarity_scores)
177-
if is_nan.sum() > 0:
178-
logger.warning(
179-
f"Found {is_nan.sum()} NaN values in the similarity scores. Replacing NaN values with -1."
225+
similarity_scores_top_k_values = (
226+
similarity_scores_top_k_values.cpu().tolist()
180227
)
181-
similarity_scores[is_nan] = -1
182-
183-
# Get top-k values
184-
similarity_scores_top_k_values, similarity_scores_top_k_idx = torch.topk(
185-
similarity_scores,
186-
min(
187-
top_k + 1,
188-
len(similarity_scores[1])
189-
if len(similarity_scores) > 1
190-
else len(similarity_scores[-1]),
191-
),
192-
dim=1,
193-
largest=True,
194-
sorted=return_sorted,
195-
)
196-
similarity_scores_top_k_values = (
197-
similarity_scores_top_k_values.cpu().tolist()
198-
)
199-
similarity_scores_top_k_idx = similarity_scores_top_k_idx.cpu().tolist()
200-
201-
for query_itr in range(len(query_embeddings)):
202-
query_id = query_ids[query_itr]
203-
for sub_corpus_id, score in zip(
204-
similarity_scores_top_k_idx[query_itr],
205-
similarity_scores_top_k_values[query_itr],
206-
):
207-
corpus_id = corpus_ids[corpus_start_idx + sub_corpus_id]
208-
if len(result_heaps[query_id]) < top_k:
209-
# Push item on the heap
210-
heapq.heappush(result_heaps[query_id], (score, corpus_id))
211-
else:
212-
# If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
213-
heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
214-
228+
similarity_scores_top_k_idx = similarity_scores_top_k_idx.cpu().tolist()
229+
230+
for query_itr in range(len(query_batch)):
231+
global_query_index = query_itr + (query_batch_index * self.batch_size)
232+
#query_id = query_ids[query_itr]
233+
query_id = query_ids[global_query_index]
234+
for sub_corpus_id, score in zip(
235+
similarity_scores_top_k_idx[query_itr],
236+
similarity_scores_top_k_values[query_itr],
237+
):
238+
#corpus_id = corpus_ids[corpus_start_idx + sub_corpus_id]
239+
corpus_id = corpus_ids[chunk_start_idx + sub_corpus_id] # Use chunk_start_idx here
240+
if len(result_heaps[query_id]) < top_k:
241+
# Push item on the heap
242+
heapq.heappush(result_heaps[query_id], (score, corpus_id))
243+
else:
244+
# If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
245+
heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
246+
247+
sub_corpus_embeddings = sub_corpus_embeddings.to('cpu') #Move corpus chunk back to cpu because it will be reused
248+
249+
# After processing the batch, delete the query_batch tensor to free GPU memory
250+
del query_batch
251+
torch.cuda.empty_cache() # Optionally free up any cached GPU memory
252+
215253
for qid in result_heaps:
216254
for score, corpus_id in result_heaps[qid]:
217255
self.results[qid][corpus_id] = score
@@ -351,7 +389,7 @@ def encode_conversations(
351389
)
352390
queries = self.convert_conv_history_to_query(model, conversations) # type: ignore
353391
return model.encode(
354-
queries, task_name=task_name, prompt_type=PromptType.query, **kwargs
392+
queries, output_value="token_embeddings", task_name=task_name, prompt_type=PromptType.query, **kwargs
355393
) # type: ignore
356394

357395
@staticmethod
@@ -421,7 +459,7 @@ def encode(
421459
sentences, task_name, prompt_type=prompt_type, **kwargs
422460
)
423461
return self.model.encode(
424-
sentences, task_name=task_name, prompt_type=prompt_type, **kwargs
462+
sentences, output_value="token_embeddings", task_name=task_name, prompt_type=prompt_type, **kwargs
425463
)
426464

427465

0 commit comments

Comments
 (0)