Skip to content

Commit 8c7af37

Browse files
committed
add highest predicted tanimoto score to results
1 parent 3d9a1d8 commit 8c7af37

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

ms2query/ms2query_development/ReferenceLibrary.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,9 @@ def run_ms2query(
142142
num_of_query_embeddings = query_embeddings.embeddings.shape[0]
143143

144144
library_index_highest_ms2deepscore = np.zeros((num_of_query_embeddings), dtype=int)
145+
highest_ms2deepscore_values = np.zeros((num_of_query_embeddings), dtype=float)
145146
ms2query_scores = []
147+
146148
for start_idx in tqdm(
147149
range(0, num_of_query_embeddings, batch_size),
148150
desc="Predicting highest ms2deepscore per batch of "
@@ -153,20 +155,29 @@ def run_ms2query(
153155
end_idx = min(start_idx + batch_size, num_of_query_embeddings)
154156
selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx]
155157
score_matrix = cosine_similarity_matrix(selected_query_embeddings, self.reference_embeddings.embeddings)
158+
156159
highest_score_idx = np.argmax(score_matrix, axis=1)
160+
highest_score_values = np.max(score_matrix, axis=1)
161+
157162
library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx
163+
highest_ms2deepscore_values[start_idx:end_idx] = highest_score_values
158164

159165
# get predicted inchikeys
160166
predicted_inchikeys = self.reference_metadata.iloc[highest_score_idx]["inchikey"]
167+
161168
# Compute MS2Query reliability score
162169
ms2query_scores.extend(
163170
get_ms2query_reliability_prediction(
164-
predicted_inchikeys, self.spectrum_indices_per_inchikey, self.top_k_tanimoto_scores, score_matrix
171+
predicted_inchikeys,
172+
self.spectrum_indices_per_inchikey,
173+
self.top_k_tanimoto_scores,
174+
score_matrix,
165175
)
166176
)
167177

168178
# construct results df
169-
results = self.reference_metadata.iloc[library_index_highest_ms2deepscore]
179+
results = self.reference_metadata.iloc[library_index_highest_ms2deepscore].copy()
180+
results["predicted_tanimoto"] = highest_ms2deepscore_values
170181
results["ms2query_reliability_prediction"] = ms2query_scores
171182
return results
172183

0 commit comments

Comments
 (0)