Skip to content

Commit 92b434c

Browse files
committed
add fingerprint index
1 parent b0b7748 commit 92b434c

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

ms2query/library_io.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tqdm import tqdm
99
from ms2query import MS2QueryDatabase, MS2QueryLibrary
1010
from ms2query.data_processing.merging_utils import cluster_block, get_merged_spectra
11-
from ms2query.database import EmbeddingIndex
11+
from ms2query.database import EmbeddingIndex, FingerprintSparseIndex
1212
from ms2query.database.spectra_merging import _split_by_mode_charge
1313

1414

@@ -18,6 +18,7 @@
1818
_SQLITE_NAME = "ms2query_library.sqlite"
1919
_EMB_TABLE = "embeddings"
2020
_EMB_INDEX_BASENAME = "embedding_index" # will create embedding_index.{nmslib,ids.npy,meta.json}
21+
_FP_INDEX_BASENAME = "fingerprint_index" # will create fingerprint_index.{nmslib,ids.npy,meta.json}
2122

2223

2324
def _handle_default_settings(settings: dict) -> dict:
@@ -81,6 +82,7 @@ def create_new_library(
8182
model_path: str,
8283
additional_compound_file: Optional[str] = None,
8384
build_embedding_index: bool = True,
85+
build_fingerprint_index: bool = True,
8486
embedding_index_params: Optional[dict] = None,
8587
compute_embeddings_batch_rows: int = 4096,
8688
**settings,
@@ -101,6 +103,8 @@ def create_new_library(
101103
CSV/TSV file with additional compounds (inchikey/smiles/etc.). No fingerprints assumed here.
102104
build_embedding_index : bool
103105
Whether to build the nmslib cosine HNSW index over embeddings.
106+
build_fingerprint_index : bool
107+
Whether to build the FingerprintSparseIndex over compound fingerprints.
104108
embedding_index_params : dict
105109
Params for HNSW: {'M': int, 'ef_construction': int, 'post_init_ef': int, 'batch_rows': int}
106110
compute_embeddings_batch_rows : int
@@ -139,6 +143,8 @@ def create_new_library(
139143
_print_progress(f"Inserted {creation_stats['n_inserted_spectra']} spectra.")
140144
_print_progress(f"Mapped {creation_stats['n_mapped']} spectra to compounds; "
141145
f"created {creation_stats['n_new_compounds']} new compounds.")
146+
stats = ms2query_db.ref_cdb.compute_fingerprints_missing()
147+
_print_progress(f"Computed fingerprints for {stats['updated']} compounds.")
142148

143149
if additional_compound_file is not None:
144150
if not additional_compound_file.lower().endswith((".csv", ".tsv", ".txt")):
@@ -199,6 +205,23 @@ def create_new_library(
199205
_print_progress(f"Saved EmbeddingIndex files with prefix: {emb_prefix}")
200206
lib.set_embedding_index(emb_index)
201207

208+
if build_fingerprint_index:
209+
# TODO: this is not efficient yet; improve later
210+
_print_progress("Building FingerprintSparseIndex ...")
211+
results = lib.db.ref_cdb.get_all_fingerprints_and_comp_ids()
212+
max_bits = [x[0][-1] for x in results["fingerprints"]]
213+
fp_index = FingerprintSparseIndex(dim=int(max(max_bits) + 1))
214+
215+
fp_index.build_index(
216+
results["fingerprints"],
217+
results["comp_ids"],
218+
)
219+
fp_prefix = str(out_dir / _FP_INDEX_BASENAME)
220+
fp_index.save_index(fp_prefix)
221+
_print_progress(f"Saved FingerprintSparseIndex files with prefix: {fp_prefix}")
222+
lib.set_fingerprint_index(fp_index)
223+
224+
202225
# -----------------------------
203226
# Manifest
204227
# -----------------------------
@@ -208,6 +231,7 @@ def create_new_library(
208231
"embedding_table": _EMB_TABLE,
209232
"model_path": model_path, # stored for convenience; not copied
210233
"embedding_index_prefix": _EMB_INDEX_BASENAME if build_embedding_index else None,
234+
"fingerprint_index_prefix": _FP_INDEX_BASENAME if build_fingerprint_index else None,
211235
"settings": settings,
212236
}
213237
with open(out_dir / _MANIFEST_NAME, "w", encoding="utf-8") as f:
@@ -262,6 +286,13 @@ def load_created_library(folder: str) -> MS2QueryLibrary:
262286
emb_index = EmbeddingIndex()
263287
emb_index.load_index(str(out_dir / emb_prefix))
264288
lib.set_embedding_index(emb_index)
289+
290+
# Load FingerprintSparseIndex if present
291+
fp_prefix = manifest.get("fingerprint_index_prefix")
292+
if fp_prefix:
293+
fp_index = FingerprintSparseIndex()
294+
fp_index.load_index(str(out_dir / fp_prefix))
295+
lib.set_fingerprint_index(fp_index)
265296

266297
# (Optional) Load fingerprint index here if/when you add it later.
267298

0 commit comments

Comments
 (0)