@@ -142,7 +142,9 @@ def run_ms2query(
142142 num_of_query_embeddings = query_embeddings .embeddings .shape [0 ]
143143
144144 library_index_highest_ms2deepscore = np .zeros ((num_of_query_embeddings ), dtype = int )
145+ highest_ms2deepscore_values = np .zeros ((num_of_query_embeddings ), dtype = float )
145146 ms2query_scores = []
147+
146148 for start_idx in tqdm (
147149 range (0 , num_of_query_embeddings , batch_size ),
148150 desc = "Predicting highest ms2deepscore per batch of "
@@ -153,20 +155,29 @@ def run_ms2query(
153155 end_idx = min (start_idx + batch_size , num_of_query_embeddings )
154156 selected_query_embeddings = query_embeddings .embeddings [start_idx :end_idx ]
155157 score_matrix = cosine_similarity_matrix (selected_query_embeddings , self .reference_embeddings .embeddings )
158+
156159 highest_score_idx = np .argmax (score_matrix , axis = 1 )
160+ highest_score_values = np .max (score_matrix , axis = 1 )
161+
157162 library_index_highest_ms2deepscore [start_idx :end_idx ] = highest_score_idx
163+ highest_ms2deepscore_values [start_idx :end_idx ] = highest_score_values
158164
159165 # get predicted inchikeys
160166 predicted_inchikeys = self .reference_metadata .iloc [highest_score_idx ]["inchikey" ]
167+
161168 # Compute MS2Query reliability score
162169 ms2query_scores .extend (
163170 get_ms2query_reliability_prediction (
164- predicted_inchikeys , self .spectrum_indices_per_inchikey , self .top_k_tanimoto_scores , score_matrix
171+ predicted_inchikeys ,
172+ self .spectrum_indices_per_inchikey ,
173+ self .top_k_tanimoto_scores ,
174+ score_matrix ,
165175 )
166176 )
167177
168178 # construct results df
169- results = self .reference_metadata .iloc [library_index_highest_ms2deepscore ]
179+ results = self .reference_metadata .iloc [library_index_highest_ms2deepscore ].copy ()
180+ results ["predicted_tanimoto" ] = highest_ms2deepscore_values
170181 results ["ms2query_reliability_prediction" ] = ms2query_scores
171182 return results
172183
0 commit comments