2222 confidence_scores ,
2323 convert_conv_history_to_query ,
2424 cos_sim ,
25+ energy_calc ,
26+ energy_distance ,
2527 download ,
2628 hole ,
2729 mrr ,
@@ -95,6 +97,19 @@ def __init__(
9597 # custom functions can be used by extending the DenseRetrievalExactSearch class
9698 self .predict = self .model .predict
9799
def precompute_corpus_embeddings(self, corpus, model, batch_size, chunk_size):
    """Encode the whole corpus in fixed-size chunks, keeping results on CPU.

    Encoding chunk-by-chunk bounds peak accelerator memory: each chunk's
    embeddings are moved to CPU immediately after encoding, so at most one
    chunk of embeddings lives on the GPU at a time.

    Args:
        corpus: Sequence of documents accepted by ``model.encode_corpus``.
        model: Encoder exposing
            ``encode_corpus(chunk, batch_size=..., convert_to_tensor=True)``
            and returning a tensor with one row per document.
        batch_size: Batch size forwarded to ``model.encode_corpus``.
        chunk_size: Number of documents encoded per chunk.

    Returns:
        list: One CPU tensor per chunk, in corpus order; the last chunk may
        hold fewer than ``chunk_size`` rows. Empty list for an empty corpus.
    """
    # Use logging (consistent with the rest of this module) instead of print().
    log = logging.getLogger(__name__)
    log.info("Corpus length: %d, batch size: %d", len(corpus), batch_size)
    all_corpus_embeddings = []
    for start_idx in range(0, len(corpus), chunk_size):
        end_idx = min(start_idx + chunk_size, len(corpus))
        log.info("Encoding corpus chunk [%d:%d]", start_idx, end_idx)
        chunk = corpus[start_idx:end_idx]
        embeddings = model.encode_corpus(
            chunk, batch_size=batch_size, convert_to_tensor=True
        )
        # Offload to CPU right away so the next chunk can reuse GPU memory.
        all_corpus_embeddings.append(embeddings.to("cpu"))
    return all_corpus_embeddings
112+
98113 def search (
99114 self ,
100115 corpus : dict [str , dict [str , str ]],
@@ -120,8 +135,9 @@ def search(
120135 ** self .encode_kwargs ,
121136 )
122137 else :
123- query_embeddings = self .model .encode (
138+ query_embeddings , attention_masks = self .model .encode (
124139 queries , # type: ignore
140+ output_value = "token_embeddings" ,
125141 task_name = task_name ,
126142 prompt_type = PromptType .query ,
127143 ** self .encode_kwargs ,
@@ -135,83 +151,105 @@ def search(
135151 corpus = [corpus [cid ] for cid in corpus_ids ] # type: ignore
136152
137153 logger .info ("Encoding Corpus in batches... Warning: This might take a while!" )
154+ # Precompute all corpus embeddings
155+ all_corpus_embeddings = self .precompute_corpus_embeddings (
156+ corpus = corpus ,
157+ model = self .model , # Use the corpus-specific model
158+ batch_size = self .batch_size ,
159+ chunk_size = self .corpus_chunk_size
160+ )
138161
139162 itr = range (0 , len (corpus ), self .corpus_chunk_size )
140163
141164 result_heaps = {
142165 qid : [] for qid in query_ids
143166 } # Keep only the top-k docs for each query
144- for batch_num , corpus_start_idx in enumerate (itr ):
145- logger .info (f"Encoding Batch { batch_num + 1 } /{ len (itr )} ..." )
146- corpus_end_idx = min (corpus_start_idx + self .corpus_chunk_size , len (corpus ))
147-
148- # Encode chunk of corpus
149- if (
150- self .save_corpus_embeddings
151- and request_qid
152- and len (self .corpus_embeddings [request_qid ])
153- ):
154- sub_corpus_embeddings = torch .tensor (
155- self .corpus_embeddings [request_qid ][batch_num ]
156- )
157- else :
158- # Encode chunk of corpus
159- sub_corpus_embeddings = self .model .encode (
160- corpus [corpus_start_idx :corpus_end_idx ], # type: ignore
161- task_name = task_name ,
162- prompt_type = PromptType .passage ,
163- request_qid = request_qid ,
164- ** self .encode_kwargs ,
165- )
166- if self .save_corpus_embeddings and request_qid :
167- self .corpus_embeddings [request_qid ].append (sub_corpus_embeddings )
168167
169- # Compute similarites using self defined similarity otherwise default to cosine-similarity
170- if hasattr (self .model , "similarity" ):
171- similarity_scores = self .model .similarity (
172- query_embeddings , sub_corpus_embeddings
168+ for query_batch_index , query_batch in enumerate (query_embeddings ):
169+ for chunk_idx , sub_corpus_embeddings in enumerate (all_corpus_embeddings ):
170+ #for batch_num, corpus_start_idx in enumerate(itr):
171+ logger .info (f"Encoding Batch { batch_num + 1 } /{ len (itr )} ..." )
172+ #corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus))
173+ sub_corpus_embeddings = sub_corpus_embeddings .to ('cuda' )
174+ chunk_start_idx = chunk_idx * self .corpus_chunk_size # Calculate the starting index of this chunk
175+
176+
177+ # Encode chunk of corpus
178+ #if (
179+ # self.save_corpus_embeddings
180+ # and request_qid
181+ # and len(self.corpus_embeddings[request_qid])
182+ #):
183+ # sub_corpus_embeddings = torch.tensor(
184+ # self.corpus_embeddings[request_qid][batch_num]
185+ # )
186+ #else:
187+ # Encode chunk of corpus
188+ # sub_corpus_embeddings = self.model.encode(
189+ # corpus[corpus_start_idx:corpus_end_idx], # type: ignore
190+ # task_name=task_name,
191+ # prompt_type=PromptType.passage,
192+ # request_qid=request_qid,
193+ # **self.encode_kwargs,
194+ # )
195+ # if self.save_corpus_embeddings and request_qid:
196+ # self.corpus_embeddings[request_qid].append(sub_corpus_embeddings)
197+
198+ # Compute similarites using self defined similarity otherwise default to cosine-similarity
199+ #if hasattr(self.model, "similarity"):
200+ # similarity_scores = self.model.similarity(
201+ # query_embeddings, sub_corpus_embeddings
202+ # )
203+ #else:
204+ similarity_scores = energy_distance (query_embeddings , sub_corpus_embeddings , attention_masks [query_batch_index ])
205+ is_nan = torch .isnan (similarity_scores )
206+ if is_nan .sum () > 0 :
207+ logger .warning (
208+ f"Found { is_nan .sum ()} NaN values in the similarity scores. Replacing NaN values with -inf."
209+ )
210+ similarity_scores [is_nan ] = float ('inf' ) * - 1
211+
212+ # Get top-k values
213+ similarity_scores_top_k_values , similarity_scores_top_k_idx = torch .topk (
214+ similarity_scores ,
215+ min (
216+ top_k + 1 ,
217+ len (similarity_scores [1 ])
218+ if len (similarity_scores ) > 1
219+ else len (similarity_scores [- 1 ]),
220+ ),
221+ dim = 1 ,
222+ largest = True ,
223+ sorted = return_sorted ,
173224 )
174- else :
175- similarity_scores = cos_sim (query_embeddings , sub_corpus_embeddings )
176- is_nan = torch .isnan (similarity_scores )
177- if is_nan .sum () > 0 :
178- logger .warning (
179- f"Found { is_nan .sum ()} NaN values in the similarity scores. Replacing NaN values with -1."
225+ similarity_scores_top_k_values = (
226+ similarity_scores_top_k_values .cpu ().tolist ()
180227 )
181- similarity_scores [is_nan ] = - 1
182-
183- # Get top-k values
184- similarity_scores_top_k_values , similarity_scores_top_k_idx = torch .topk (
185- similarity_scores ,
186- min (
187- top_k + 1 ,
188- len (similarity_scores [1 ])
189- if len (similarity_scores ) > 1
190- else len (similarity_scores [- 1 ]),
191- ),
192- dim = 1 ,
193- largest = True ,
194- sorted = return_sorted ,
195- )
196- similarity_scores_top_k_values = (
197- similarity_scores_top_k_values .cpu ().tolist ()
198- )
199- similarity_scores_top_k_idx = similarity_scores_top_k_idx .cpu ().tolist ()
200-
201- for query_itr in range (len (query_embeddings )):
202- query_id = query_ids [query_itr ]
203- for sub_corpus_id , score in zip (
204- similarity_scores_top_k_idx [query_itr ],
205- similarity_scores_top_k_values [query_itr ],
206- ):
207- corpus_id = corpus_ids [corpus_start_idx + sub_corpus_id ]
208- if len (result_heaps [query_id ]) < top_k :
209- # Push item on the heap
210- heapq .heappush (result_heaps [query_id ], (score , corpus_id ))
211- else :
212- # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
213- heapq .heappushpop (result_heaps [query_id ], (score , corpus_id ))
214-
228+ similarity_scores_top_k_idx = similarity_scores_top_k_idx .cpu ().tolist ()
229+
230+ for query_itr in range (len (query_batch )):
231+ global_query_index = query_itr + (query_batch_index * self .batch_size )
232+ #query_id = query_ids[query_itr]
233+ query_id = query_ids [global_query_index ]
234+ for sub_corpus_id , score in zip (
235+ similarity_scores_top_k_idx [query_itr ],
236+ similarity_scores_top_k_values [query_itr ],
237+ ):
238+ #corpus_id = corpus_ids[corpus_start_idx + sub_corpus_id]
239+ corpus_id = corpus_ids [chunk_start_idx + sub_corpus_id ] # Use chunk_start_idx here
240+ if len (result_heaps [query_id ]) < top_k :
241+ # Push item on the heap
242+ heapq .heappush (result_heaps [query_id ], (score , corpus_id ))
243+ else :
244+ # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
245+ heapq .heappushpop (result_heaps [query_id ], (score , corpus_id ))
246+
247+ sub_corpus_embeddings = sub_corpus_embeddings .to ('cpu' ) #Move corpus chunk back to cpu because it will be reused
248+
249+ # After processing the batch, delete the query_batch tensor to free GPU memory
250+ del query_batch
251+ torch .cuda .empty_cache () # Optionally free up any cached GPU memory
252+
215253 for qid in result_heaps :
216254 for score , corpus_id in result_heaps [qid ]:
217255 self .results [qid ][corpus_id ] = score
@@ -351,7 +389,7 @@ def encode_conversations(
351389 )
352390 queries = self .convert_conv_history_to_query (model , conversations ) # type: ignore
353391 return model .encode (
354- queries , task_name = task_name , prompt_type = PromptType .query , ** kwargs
392+ queries , output_value = "token_embeddings" , task_name = task_name , prompt_type = PromptType .query , ** kwargs
355393 ) # type: ignore
356394
357395 @staticmethod
@@ -421,7 +459,7 @@ def encode(
421459 sentences , task_name , prompt_type = prompt_type , ** kwargs
422460 )
423461 return self .model .encode (
424- sentences , task_name = task_name , prompt_type = prompt_type , ** kwargs
462+ sentences , output_value = "token_embeddings" , task_name = task_name , prompt_type = prompt_type , ** kwargs
425463 )
426464
427465
0 commit comments