Skip to content

Commit e882f00

Browse files
committed
Add EvaluateMethods, a general class for benchmarking analogue search and exact match search
1 parent fa4afdd commit e882f00

File tree

1 file changed

+205
-0
lines changed

1 file changed

+205
-0
lines changed
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import random
2+
3+
import numpy as np
4+
from typing import Callable, Tuple, List
5+
6+
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
7+
from tqdm import tqdm
8+
9+
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSetBase
10+
11+
12+
class EvaluateMethods:
    """Benchmark prediction functions for analogue search and exact match search.

    A prediction function is called as ``prediction_function(library, query_spectra)``. Only the first element of
    the tuple it returns is used here: the predicted inchikeys, indexed by the position of the query spectrum.
    A predicted inchikey of None is treated as "no prediction".
    """

    def __init__(
        self, training_spectrum_set: SpectraWithFingerprints, validation_spectrum_set: SpectraWithFingerprints
    ):
        """Store both spectrum sets and switch off their ``progress_bars`` flag.

        Note that the flag is set on the objects passed in, so the caller's sets are modified.
        """
        self.training_spectrum_set = training_spectrum_set
        self.validation_spectrum_set = validation_spectrum_set

        self.training_spectrum_set.progress_bars = False
        self.validation_spectrum_set.progress_bars = False

    def benchmark_analogue_search(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ) -> float:
        """Return the average Tanimoto score between the predicted and the actual inchikeys.

        Predictions are made for the validation set with the training set as library. The scores are first
        averaged per validation inchikey and then over all inchikeys, so every inchikey has the same weight
        regardless of its number of spectra. A spectrum without a prediction (None) scores 0.0.

        Args:
            prediction_function: Called with (training set, validation set). The first element of its result is
                indexed with the validation spectrum indexes.

        Returns:
            The average over all validation inchikeys of the average Tanimoto score per inchikey.

        Raises:
            ValueError: If the validation set contains no inchikeys, since no average can be calculated.
        """
        spectrum_indexes_per_inchikey = self.validation_spectrum_set.spectrum_indexes_per_inchikey
        if not spectrum_indexes_per_inchikey:
            # Fail early with a clear message instead of a ZeroDivisionError after running all predictions.
            raise ValueError("The validation spectrum set contains no inchikeys, so no score can be calculated")

        predicted_inchikeys, _ = prediction_function(self.training_spectrum_set, self.validation_spectrum_set)
        average_scores_per_inchikey = []

        # Calculate score per unique inchikey
        for inchikey, matching_spectrum_indexes in tqdm(
            spectrum_indexes_per_inchikey.items(),
            desc="Calculating analogue accuracy per inchikey",
        ):
            prediction_scores = []
            for index in matching_spectrum_indexes:
                predicted_inchikey = predicted_inchikeys[index]
                if predicted_inchikey is None:
                    prediction_scores.append(0.0)
                    continue
                # The predicted inchikey comes from the library, the actual one from the validation set.
                predicted_fingerprint = self.training_spectrum_set.inchikey_fingerprint_pairs[predicted_inchikey]
                actual_fingerprint = self.validation_spectrum_set.inchikey_fingerprint_pairs[inchikey]
                prediction_scores.append(
                    calculate_tanimoto_score_between_pair(predicted_fingerprint, actual_fingerprint)
                )

            average_scores_per_inchikey.append(sum(prediction_scores) / len(prediction_scores))
        return sum(average_scores_per_inchikey) / len(average_scores_per_inchikey)

    def benchmark_exact_matching_within_ionmode(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
        ionmode: str,
    ) -> float:
        """Test the accuracy at retrieving exact matches from the library

        For each inchikey with more than 1 spectrum the spectra are split in two sets. Half for each inchikey is added
        to the library (training set), for the other half predictions are made. Thereby there is always an exact match
        available. Only the highest ranked prediction is considered correct if the correct inchikey is predicted. An
        accuracy per inchikey is calculated followed by calculating the average.

        Args:
            prediction_function: Called with (library, query spectra), see the class description.
            ionmode: Only validation spectra whose "ionmode" equals this value are used.

        Returns:
            The exact match accuracy averaged over the inchikeys.
        """
        selected_spectra = subset_spectra_on_ionmode(self.validation_spectrum_set, ionmode)

        set_1, set_2 = split_spectrum_set_per_inchikeys(selected_spectra)

        predicted_inchikeys = predict_between_two_sets(self.training_spectrum_set, set_1, set_2, prediction_function)

        # predict_between_two_sets returns the predictions for set_1 followed by those for set_2. Adding set_2 to
        # set_1 gives one set to score against (assumes add_spectra appends the spectra in order).
        set_1.add_spectra(set_2)
        return calculate_average_exact_match_accuracy(set_1, predicted_inchikeys)

    def exact_matches_across_ionization_modes(
        self,
        prediction_function: Callable[
            [SpectraWithFingerprints, SpectraWithFingerprints], Tuple[List[str], List[float]]
        ],
    ) -> float:
        """Test the accuracy at retrieving exact matches from the library if only available in other ionisation mode

        Each val spectrum is matched against the training set with the other val spectra of the same inchikey, but other
        ionisation mode added to the library.

        Args:
            prediction_function: Called with (library, query spectra), see the class description.

        Returns:
            The exact match accuracy averaged over the inchikeys.
        """
        pos_set, neg_set = split_spectrum_set_per_inchikey_across_ionmodes(self.validation_spectrum_set)
        predicted_inchikeys = predict_between_two_sets(
            self.training_spectrum_set, pos_set, neg_set, prediction_function
        )
        # The predictions are ordered as pos_set followed by neg_set. Adding neg_set to pos_set gives one set to
        # score against (assumes add_spectra appends the spectra in order).
        pos_set.add_spectra(neg_set)
        return calculate_average_exact_match_accuracy(pos_set, predicted_inchikeys)

    def get_accuracy_recall_curve(self):
        """This method should test the recall accuracy balance.

        All of the used methods use a threshold which indicates quality of prediction.
        A method that can predict well when a prediction is accurate is beneficial.
        We need a method to test this.

        One method is generating a recall accuracy curve. This could be done for both the analogue search predictions
        and the exact match predictions. By returning the predicted score for a match this method could create an
        accuracy recall plot.

        Raises:
            NotImplementedError: Always, this benchmark is not implemented yet.
        """
        raise NotImplementedError
110+
111+
112+
def predict_between_two_sets(
    library: SpectrumSetBase, query_set_1: SpectrumSetBase, query_set_2: SpectrumSetBase, prediction_function
):
    """Predict inchikeys for both query sets, each time with the other query set added to the library.

    For every query set a fresh ``library.copy()`` is made, the other query set is added to it with
    ``add_spectra`` and ``prediction_function(extended_library, query_set)`` is called. This makes sure an exact
    match is available in the library, which is needed for testing exact matching.

    Returns:
        The predictions for query_set_1 followed by the predictions for query_set_2 (joined with ``+``). Only the
        first element of each prediction_function result is used.
    """
    predictions_per_query_set = []
    pairings = ((query_set_1, query_set_2), (query_set_2, query_set_1))
    for query_set, spectra_added_to_library in pairings:
        extended_library = library.copy()
        extended_library.add_spectra(spectra_added_to_library)
        predicted_inchikeys, _ = prediction_function(extended_library, query_set)
        predictions_per_query_set.append(predicted_inchikeys)

    predictions_set_1, predictions_set_2 = predictions_per_query_set
    return predictions_set_1 + predictions_set_2
127+
128+
129+
def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predicted_inchikeys: List[str]):
    """Return the exact match accuracy, averaged over the inchikeys in spectrum_set.

    For each inchikey the fraction of its spectra for which that inchikey was predicted is calculated. These
    fractions are then averaged, so every inchikey has the same weight regardless of its number of spectra.

    Args:
        spectrum_set: Set of which ``spectra`` and ``spectrum_indexes_per_inchikey`` are used.
        predicted_inchikeys: The predicted inchikey per spectrum, indexed by spectrum index.

    Returns:
        The average of the per-inchikey accuracies.

    Raises:
        ValueError: If the number of predictions differs from the number of spectra, or if the spectrum set
            contains no inchikeys (no average can be calculated).
    """
    if len(spectrum_set.spectra) != len(predicted_inchikeys):
        raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ")

    spectrum_indexes_per_inchikey = spectrum_set.spectrum_indexes_per_inchikey
    if not spectrum_indexes_per_inchikey:
        # Without this check an empty set ends in an uninformative ZeroDivisionError at the final average.
        raise ValueError("The spectrum set contains no inchikeys, so no exact match accuracy can be calculated")

    exact_match_accuracy_per_inchikey = []
    for inchikey, spectrum_indexes in tqdm(
        spectrum_indexes_per_inchikey.items(), desc="Calculating exact match accuracy per inchikey"
    ):
        correctly_predicted = sum(1 for spectrum_idx in spectrum_indexes if inchikey == predicted_inchikeys[spectrum_idx])
        exact_match_accuracy_per_inchikey.append(correctly_predicted / len(spectrum_indexes))
    return sum(exact_match_accuracy_per_inchikey) / len(exact_match_accuracy_per_inchikey)
143+
144+
145+
def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
    """Split a spectrum set randomly into two sets, dividing the spectra of each inchikey over both.

    Inchikeys with fewer than two spectra are left out of both sets, since no exact match could be added to the
    library for them. For an odd number of spectra the second set receives the extra spectrum.

    Args:
        spectrum_set: The spectrum set to split. The index lists in its ``spectrum_indexes_per_inchikey`` are
            left untouched.

    Returns:
        A tuple (set_1, set_2), both created with ``spectrum_set.subset_spectra``.
    """
    indexes_set_1 = []
    indexes_set_2 = []
    for inchikey in tqdm(spectrum_set.spectrum_indexes_per_inchikey.keys(), desc="Splitting spectra per inchikey"):
        # Work on a copy: random.shuffle works in place and would otherwise reorder the index lists stored in
        # the spectrum set that was passed in.
        shuffled_spectrum_indexes = list(spectrum_set.spectrum_indexes_per_inchikey[inchikey])
        if len(shuffled_spectrum_indexes) < 2:
            # all single spectra are excluded from this test, since no exact match can be added to the library
            continue
        random.shuffle(shuffled_spectrum_indexes)
        split_index = len(shuffled_spectrum_indexes) // 2
        indexes_set_1.extend(shuffled_spectrum_indexes[:split_index])
        indexes_set_2.extend(shuffled_spectrum_indexes[split_index:])
    return spectrum_set.subset_spectra(indexes_set_1), spectrum_set.subset_spectra(indexes_set_2)
160+
161+
162+
def split_spectrum_set_per_inchikey_across_ionmodes(
    spectrum_set: SpectrumSetBase,
) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
    """Split a spectrum set into a positive and a negative ionmode set.

    Only inchikeys that have at least one spectrum with ionmode "positive" and at least one with ionmode
    "negative" are used. Spectra with any other ionmode value are left out.

    Returns:
        A tuple (positive set, negative set), both created with ``spectrum_set.subset_spectra``.
    """
    selected_positive_indexes = []
    selected_negative_indexes = []
    for indexes_current_inchikey in tqdm(
        spectrum_set.spectrum_indexes_per_inchikey.values(),
        desc="Splitting spectra per inchikey across ionmodes",
    ):
        # Look up the ionmode once per spectrum, then sort the indexes by ionmode.
        ionmode_per_index = [
            (spectrum_index, spectrum_set.spectra[spectrum_index].get("ionmode"))
            for spectrum_index in indexes_current_inchikey
        ]
        positive_indexes = [index for index, ionmode in ionmode_per_index if ionmode == "positive"]
        negative_indexes = [index for index, ionmode in ionmode_per_index if ionmode == "negative"]

        # Keep the inchikey only when both ionmodes are represented.
        if positive_indexes and negative_indexes:
            selected_positive_indexes.extend(positive_indexes)
            selected_negative_indexes.extend(negative_indexes)

    positive_spectra = spectrum_set.subset_spectra(selected_positive_indexes)
    negative_spectra = spectrum_set.subset_spectra(selected_negative_indexes)
    return positive_spectra, negative_spectra
194+
195+
196+
def subset_spectra_on_ionmode(spectrum_set: SpectrumSetBase, ionmode) -> SpectrumSetBase:
    """Return the subset of spectrum_set whose spectra have the given ionmode.

    A spectrum is kept when ``spectrum.get("ionmode")`` equals ``ionmode``. The positions of the kept spectra
    are passed to ``spectrum_set.subset_spectra``, whose result is returned.
    """
    matching_indexes = [
        position for position, spectrum in enumerate(spectrum_set.spectra) if spectrum.get("ionmode") == ionmode
    ]
    return spectrum_set.subset_spectra(matching_indexes)
202+
203+
204+
def calculate_tanimoto_score_between_pair(fingerprint_1: str, fingerprint_2: str) -> float:
    """Return the Tanimoto (Jaccard) score between two fingerprints.

    Each fingerprint is wrapped in a one-element array, so that ``jaccard_similarity_matrix`` can be used. The
    only entry of the resulting matrix is returned.

    NOTE(review): the parameters are annotated as ``str``, but the fingerprints are presumably numeric
    vectors - confirm against the callers.
    """
    first_as_matrix = np.array([fingerprint_1])
    second_as_matrix = np.array([fingerprint_2])
    score_matrix = jaccard_similarity_matrix(first_as_matrix, second_as_matrix)
    return score_matrix[0][0]

0 commit comments

Comments
 (0)