Skip to content

Commit 8d9a4b7

Browse files
committed
Change SpectrumSetBase to SpectrumSet
1 parent 48e9dea commit 8d9a4b7

3 files changed

Lines changed: 17 additions & 16 deletions

File tree

ms2query/benchmarking/EvaluateMethods.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from matchms.similarity.vector_similarity_functions import jaccard_similarity_matrix
55
from tqdm import tqdm
6-
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSetBase
6+
from ms2query.benchmarking.SpectrumDataSet import SpectraWithFingerprints, SpectrumSet
77

88

99
class EvaluateMethods:
@@ -107,7 +107,7 @@ def get_accuracy_recall_curve(self):
107107

108108

109109
def predict_between_two_sets(
110-
library: SpectrumSetBase, query_set_1: SpectrumSetBase, query_set_2: SpectrumSetBase, prediction_function
110+
library: SpectrumSet, query_set_1: SpectrumSet, query_set_2: SpectrumSet, prediction_function
111111
):
112112
"""Makes predictions between query sets and the library, with the other query set added.
113113
@@ -123,7 +123,7 @@ def predict_between_two_sets(
123123
return predicted_inchikeys_1 + predicted_inchikeys_2
124124

125125

126-
def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predicted_inchikeys: List[str]):
126+
def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSet, predicted_inchikeys: List[str]):
127127
if len(spectrum_set.spectra) != len(predicted_inchikeys):
128128
raise ValueError("The number of spectra should be equal to the number of predicted inchikeys ")
129129
exact_match_accuracy_per_inchikey = []
@@ -139,7 +139,7 @@ def calculate_average_exact_match_accuracy(spectrum_set: SpectrumSetBase, predic
139139
return sum(exact_match_accuracy_per_inchikey) / len(exact_match_accuracy_per_inchikey)
140140

141141

142-
def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
142+
def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSet) -> Tuple[SpectrumSet, SpectrumSet]:
143143
"""Splits a spectrum set into two.
144144
For each inchikey with more than one spectrum the spectra are divided over the two sets"""
145145
indexes_set_1 = []
@@ -157,8 +157,8 @@ def split_spectrum_set_per_inchikeys(spectrum_set: SpectrumSetBase) -> Tuple[Spe
157157

158158

159159
def split_spectrum_set_per_inchikey_across_ionmodes(
160-
spectrum_set: SpectrumSetBase,
161-
) -> Tuple[SpectrumSetBase, SpectrumSetBase]:
160+
spectrum_set: SpectrumSet,
161+
) -> Tuple[SpectrumSet, SpectrumSet]:
162162
"""Splits a spectrum set in two sets on ionmode. Only uses spectra for inchikeys with at least 1 pos and 1 neg"""
163163
all_pos_indexes = []
164164
all_neg_indexes = []
@@ -190,7 +190,7 @@ def split_spectrum_set_per_inchikey_across_ionmodes(
190190
return pos_val_spectra, neg_val_spectra
191191

192192

193-
def subset_spectra_on_ionmode(spectrum_set: SpectrumSetBase, ionmode) -> SpectrumSetBase:
193+
def subset_spectra_on_ionmode(spectrum_set: SpectrumSet, ionmode) -> SpectrumSet:
194194
spectrum_indexes_to_keep = []
195195
for i, spectrum in enumerate(spectrum_set.spectra):
196196
if spectrum.get("ionmode") == ionmode:

ms2query/benchmarking/SpectrumDataSet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from tqdm import tqdm
99

1010

11-
class SpectrumSetBase:
11+
12+
class SpectrumSet:
1213
"""Stores a spectrum dataset making it easy and fast to split on molecules"""
1314

1415
def __init__(self, spectra: List[Spectrum], progress_bars=False):
@@ -36,10 +37,10 @@ def _add_spectra_and_group_per_inchikey(self, spectra: List[Spectrum]):
3637
]
3738
return updated_inchikeys
3839

39-
def add_spectra(self, new_spectra: "SpectrumSetBase"):
40+
def add_spectra(self, new_spectra: "SpectrumSet"):
4041
return self._add_spectra_and_group_per_inchikey(new_spectra.spectra)
4142

42-
def subset_spectra(self, spectrum_indexes) -> "SpectrumSetBase":
43+
def subset_spectra(self, spectrum_indexes) -> "SpectrumSet":
4344
"""Returns a new instance of a subset of the spectra"""
4445
new_instance = copy.copy(self)
4546
new_instance._spectra = []
@@ -65,7 +66,7 @@ def copy(self):
6566
return new_instance
6667

6768

68-
class SpectraWithFingerprints(SpectrumSetBase):
69+
class SpectraWithFingerprints(SpectrumSet):
6970
"""Stores a spectrum dataset making it easy and fast to split on molecules"""
7071

7172
def __init__(self, spectra: List[Spectrum], fingerprint_type="daylight", nbits=4096):

tests/test_SpectrumDataSet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
from ms2query.benchmarking.SpectrumDataSet import (
44
SpectraWithFingerprints,
55
SpectraWithMS2DeepScoreEmbeddings,
6-
SpectrumSetBase,
6+
SpectrumSet,
77
)
88
from tests.conftest import create_test_spectra, get_inchikey_inchi_pairs, ms2deepscore_model
99

1010

1111
@pytest.mark.parametrize(
1212
"library",
1313
[
14-
SpectrumSetBase(create_test_spectra()),
14+
SpectrumSet(create_test_spectra()),
1515
SpectraWithFingerprints(create_test_spectra()),
1616
SpectraWithMS2DeepScoreEmbeddings(create_test_spectra(), ms2deepscore_model()),
1717
],
1818
)
1919
def test_spectrum_set_base(library):
20-
"""Test all base functionality of SpectrumSetBase is implemented correctly
20+
"""Test all base functionality of SpectrumSet is implemented correctly
2121
also for all classes inheriting from it"""
2222
# test correct init
2323
assert len(library.spectra) == 9
@@ -74,7 +74,7 @@ def test_spectra_with_fingerprints(library):
7474
(get_inchikey_inchi_pairs(3), 3), # Fully overlapping
7575
(get_inchikey_inchi_pairs(1), 3), # Fully overlapping (but not all)
7676
):
77-
spectra_to_add = SpectrumSetBase(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs))
77+
spectra_to_add = SpectrumSet(create_test_spectra(2, inchikey_inchi_pairs=inchikey_inchi_pairs))
7878
new_copy = library.copy()
7979
new_copy.add_spectra(spectra_to_add)
8080
assert len(new_copy.inchikey_fingerprint_pairs) == expected_nr_of_inchikeys
@@ -125,6 +125,6 @@ def test_spectra_with_embeddings():
125125
for i, index in enumerate(subset_indexes):
126126
assert np.all(library.embeddings[index] == subset.embeddings[i])
127127

128-
# Check that subsetting on subset works. To make sure that a subset does not become of type SpectrumSetBase
128+
# Check that subsetting a subset also works, to make sure that a subset keeps its own type and does not become a plain SpectrumSet
129129
subsetted_subset = subset.subset_spectra([0, 1])
130130
assert subsetted_subset.embeddings.shape == (2, 100)

0 commit comments

Comments
 (0)