Skip to content

Commit f559e68

Browse files
committed
refactor
1 parent f9ccd6d commit f559e68

File tree

5 files changed

+410
-387
lines changed

5 files changed

+410
-387
lines changed

ms2query/database/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .ann_index import ANNIndex
2-
from .compound_database import CompoundDatabase, SpecToCompoundMap, map_from_spectraldb_metadata
2+
from .compound_database import CompoundDatabase
33
from .database_utils import blob_to_array
4+
from .spec_to_compound_mapper import SpecToCompoundMap, map_from_spectraldb_metadata
45
from .spectra_merging import cluster_and_merge_to_sqlite, ensure_merged_tables
56
from .spectral_database import SpectralDatabase
67

ms2query/database/compound_database.py

Lines changed: 1 addition & 293 deletions
Original file line numberDiff line numberDiff line change
@@ -5,73 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from ms2query.data_processing import compute_morgan_fingerprints, inchikey14_from_full
8-
9-
10-
# =========================
11-
# Utilities & placeholders
12-
# =========================
13-
14-
def encode_sparse_fp(bits: Optional[np.ndarray], counts: Optional[np.ndarray]) -> tuple[bytes, bytes]:
    """Serialize a sparse fingerprint into two byte blobs.

    Parameters
    ----------
    bits : array-like of bit indices, or None
        Converted to uint32 before serialization.
    counts : array-like of counts, or None
        Converted to int32 before serialization.

    Returns
    -------
    (bits_blob, counts_blob)
        C-ordered raw bytes of both arrays. A ``None`` input yields ``b""``.
    """
    if bits is None:
        bits_blob = b""
    else:
        bits_blob = np.asarray(bits).astype(np.uint32, copy=False).tobytes(order="C")

    if counts is None:
        counts_blob = b""
    else:
        # Always store 4-byte int32: decode_sparse_fp reads the blob as int32, so
        # keeping narrower dtypes (uint8/uint16) would corrupt the round trip.
        counts_blob = np.asarray(counts).astype(np.int32, copy=False).tobytes(order="C")

    return bits_blob, counts_blob
38-
39-
def decode_sparse_fp(bits_blob: bytes, counts_blob: bytes) -> tuple[np.ndarray, np.ndarray]:
    """Rebuild the sparse fingerprint arrays from their byte blobs.

    Parameters
    ----------
    bits_blob : raw bytes interpreted as uint32 bit indices
    counts_blob : raw bytes interpreted as int32 counts

    Returns
    -------
    (bits_uint32, counts_int32)
        Independent copies of the decoded data. An empty blob gives a
        zero-length array of the matching dtype.
    """
    if bits_blob:
        bit_indices = np.frombuffer(bits_blob, dtype=np.uint32).copy()
    else:
        bit_indices = np.zeros(0, dtype=np.uint32)

    # Counts are read back as int32.
    if counts_blob:
        count_values = np.frombuffer(counts_blob, dtype=np.int32).copy()
    else:
        count_values = np.zeros(0, dtype=np.int32)

    return bit_indices, count_values
53-
54-
def encode_dense_fp(vec: Optional[np.ndarray]) -> bytes:
    """Return the float32 byte representation of a dense vector.

    The input is converted to float32 if needed, flattened, and written out in
    C order. ``None`` maps to an empty blob.
    """
    if vec is None:
        return b""
    flat = np.asarray(vec).astype(np.float32, copy=False).ravel()
    return flat.tobytes(order="C")
62-
63-
def decode_dense_fp(blob: bytes, dtype: str = "float32") -> np.ndarray:
    """Interpret ``blob`` as a 1-D array of ``dtype`` (float32 unless specified).

    An empty blob gives a zero-length array of the requested dtype; otherwise a
    copy of the decoded buffer is returned.
    """
    target = np.dtype(dtype)
    if blob:
        return np.frombuffer(blob, dtype=target).copy()
    return np.zeros(0, dtype=target)
68-
69-
#def decode_fp_blob(blob: bytes) -> np.ndarray:
70-
# """Decode fingerprint BLOB back to uint8 array.
71-
# Unknown length -> infer from blob size."""
72-
# if not blob:
73-
# return np.zeros(0, dtype=np.uint8)
74-
# return np.frombuffer(blob, dtype=np.uint8).copy()
8+
from ms2query.database.database_utils import decode_dense_fp, decode_sparse_fp, encode_dense_fp, encode_sparse_fp
759

7610

7711
# ==================================================
@@ -777,229 +711,3 @@ def _update_rows_sparse_with_counts(comp_ids: List[str], pairs: List[Tuple[np.nd
777711
""")["n"].iloc[0]
778712

779713
return stats
780-
781-
782-
# ==================================================
783-
# Mapping: spectrum <-> compound (spec_to_comp)
784-
# ==================================================
785-
786-
@dataclass
class SpecToCompoundMap:
    """Store (spec_id -> comp_id) mappings in a SQLite table.

    The table is created on construction if it does not exist. Using the same
    database file as SpectralDatabase is the simplest setup.

    Attributes
    ----------
    sqlite_path : path of the SQLite database file (parent directories are created).
    table : name of the mapping table.
    compound_table : name of the compound table (stored only; not used by this class).
    """
    sqlite_path: str
    table: str = "spec_to_comp"
    compound_table: str = "compounds"
    _conn: sqlite3.Connection = field(init=False, repr=False)

    # Upper bound on bound parameters per statement. Older SQLite builds default
    # to a limit of 999 variables, so large IN (...) lists are queried in chunks.
    _MAX_SQL_VARS = 900

    def __post_init__(self):
        Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(self.sqlite_path)
        self._conn.row_factory = sqlite3.Row
        self._ensure_schema()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Close the underlying connection (best effort, never raises for DB errors)."""
        try:
            self._conn.close()
        except (sqlite3.Error, AttributeError):
            # Closing is best-effort: the connection may be unusable or may never
            # have been created if __post_init__ failed.
            pass

    def _ensure_schema(self):
        """Create the mapping table and its comp_id index if they are missing."""
        cur = self._conn.cursor()
        # No strict FK enforcement (SpectralDatabase may have been created without FK pragma);
        # instead both sides are indexed for fast lookup (spec_id via the primary key).
        # The index name is derived from the table name: index names are unique per
        # database, so a fixed name would leave a second mapping table unindexed.
        cur.executescript(f"""
            CREATE TABLE IF NOT EXISTS {self.table}(
                spec_id INTEGER NOT NULL,
                comp_id TEXT NOT NULL,
                PRIMARY KEY (spec_id),
                CHECK (length(comp_id) = 14)
            );
            CREATE INDEX IF NOT EXISTS idx_{self.table}_comp ON {self.table}(comp_id);
        """)
        self._conn.commit()

    # ---------- API ----------

    def link(self, spec_id: int, comp_id: str):
        """Insert or replace a single mapping.

        Raises
        ------
        ValueError
            If ``comp_id`` is empty or not exactly 14 characters long.
        """
        if not comp_id or len(comp_id) != 14:
            raise ValueError("comp_id must be inchikey14 (14 characters).")
        self._conn.execute(f"""
            INSERT INTO {self.table} (spec_id, comp_id)
            VALUES (?, ?)
            ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id
        """, (spec_id, comp_id))
        self._conn.commit()

    def link_many(self, pairs: Iterable[Tuple[int, str]]):
        """Bulk link (spec_id, comp_id) pairs in one transaction.

        Either all pairs are written or, if any statement fails, none are
        (the transaction is rolled back and the exception re-raised).
        """
        rows = list(pairs)
        if not rows:
            return
        cur = self._conn.cursor()
        cur.execute("BEGIN")
        try:
            cur.executemany(f"""
                INSERT INTO {self.table} (spec_id, comp_id)
                VALUES (?, ?)
                ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id
            """, rows)
            cur.execute("COMMIT")
        except Exception:
            cur.execute("ROLLBACK")
            raise

    def get_comp_id_for_specs(self, spec_ids: List[int]) -> pd.DataFrame:
        """Return a DataFrame with columns [spec_id, comp_id] for the provided spec_ids.

        Ids without a mapping are simply absent from the result. Duplicate ids
        are queried once. Row order is not guaranteed.
        """
        if spec_ids is None or len(spec_ids) == 0:
            return pd.DataFrame(columns=["spec_id", "comp_id"])

        # Plain ints (sqlite3 cannot bind numpy integers), de-duplicated so that
        # chunking cannot return the same row twice.
        ids = list(dict.fromkeys(int(s) for s in spec_ids))

        records: List[Tuple[int, str]] = []
        step = self._MAX_SQL_VARS
        for start in range(0, len(ids), step):
            chunk = ids[start:start + step]
            placeholders = ",".join("?" * len(chunk))
            cursor = self._conn.execute(
                f"SELECT spec_id, comp_id FROM {self.table} WHERE spec_id IN ({placeholders})",
                chunk,
            )
            records.extend(tuple(row) for row in cursor)
        return pd.DataFrame(records, columns=["spec_id", "comp_id"])

    def get_specs_for_comp(self, comp_id: str) -> List[int]:
        """Return list of spec_ids for a given comp_id."""
        rows = self._conn.execute(f"SELECT spec_id FROM {self.table} WHERE comp_id = ?", (comp_id,)).fetchall()
        return [r[0] for r in rows]

    def get_all_mappings(self) -> pd.DataFrame:
        """Return all spec_id <-> comp_id mappings as a DataFrame."""
        rows = self._conn.execute(f"SELECT spec_id, comp_id FROM {self.table}").fetchall()
        return pd.DataFrame([tuple(r) for r in rows], columns=["spec_id", "comp_id"])
869-
870-
871-
# ==================================================
872-
# Integrations with SpectralDatabase
873-
# ==================================================
874-
875-
def map_from_spectraldb_metadata(
    spectral_db_sqlite_path: str,
    mapping_sqlite_path: Optional[str] = None,
    compounds_sqlite_path: Optional[str] = None,
    spectra_table: str = "spectra",
    compound_table: str = "compounds",
    mapping_table: str = "spec_to_comp",
    *,
    create_missing_compounds: bool = True
) -> Tuple[int, int]:
    """
    Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14),
    populate spec_to_comp, and optionally upsert minimal compounds.

    Parameters
    ----------
    spectral_db_sqlite_path : SQLite file holding the spectra table.
    mapping_sqlite_path : SQLite file for the mapping table; defaults to the spectral DB file.
    compounds_sqlite_path : SQLite file for the compound table; defaults to the spectral DB file.
    spectra_table, compound_table, mapping_table : table names to read from / write to.
    create_missing_compounds : if True, one minimal compound row per distinct comp_id
        (first spectrum seen wins) is passed to ``CompoundDatabase.upsert_many``.

    Returns: (n_mapped, n_new_compounds)
        n_mapped is the number of spectra linked; n_new_compounds is the growth of the
        compound table's row count across the upsert.
    """
    map_db_path = mapping_sqlite_path or spectral_db_sqlite_path
    c_db_path = compounds_sqlite_path or spectral_db_sqlite_path

    # We do not import the class to avoid circular imports; use sqlite directly.
    s_conn = sqlite3.connect(spectral_db_sqlite_path)
    mapper = None
    compdb = None
    try:
        s_conn.row_factory = sqlite3.Row

        mapper = SpecToCompoundMap(map_db_path, table=mapping_table)
        compdb = CompoundDatabase(c_db_path, table=compound_table)

        # Discover which columns exist in the spectra table
        cols = {r[1] for r in s_conn.execute(f"PRAGMA table_info({spectra_table})").fetchall()}
        want = ["spec_id", "inchikey", "smiles", "inchi", "classyfire_class", "classyfire_superclass"]
        have = [c for c in want if c in cols]
        select_cols = ", ".join(have)

        # Fetch everything up front so no read cursor is open while the other
        # connections (possibly on the same file) write.
        rows = s_conn.execute(f"SELECT {select_cols} FROM {spectra_table}").fetchall()

        to_link: List[Tuple[int, str]] = []
        new_comp_rows: List[Dict[str, Any]] = []
        seen: set[str] = set()  # comp_ids already queued for upsert

        for row in rows:
            rec = dict(row)
            spec_id = int(rec["spec_id"])
            ik_full = rec.get("inchikey")
            if not ik_full:
                continue
            comp_id = inchikey14_from_full(ik_full)
            if not comp_id:
                continue
            to_link.append((spec_id, comp_id))

            # Queue at most one compound row per comp_id to avoid redundant upserts.
            if create_missing_compounds and comp_id not in seen:
                seen.add(comp_id)
                new_comp_rows.append({
                    "smiles": rec.get("smiles"),
                    "inchi": rec.get("inchi"),
                    "inchikey": ik_full,
                    "classyfire_class": rec.get("classyfire_class"),
                    "classyfire_superclass": rec.get("classyfire_superclass"),
                    "fingerprint": None,  # backfill later
                })

        # Bulk linking
        if to_link:
            mapper.link_many(to_link)

        # Upsert compounds; the number of new compounds is the row-count difference.
        n_new_compounds = 0
        if create_missing_compounds and new_comp_rows:
            before = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0]
            compdb.upsert_many(new_comp_rows)
            after = compdb.sql_query(f"SELECT COUNT(*) AS n FROM {compound_table}")["n"].iloc[0]
            n_new_compounds = int(after - before)

        return len(to_link), n_new_compounds
    finally:
        # Close connections even when a step above raised.
        if mapper is not None:
            mapper.close()
        if compdb is not None:
            compdb.close()
        s_conn.close()
961-
962-
963-
def get_unique_compounds_from_spectraldb(
    spectral_db_sqlite_path: str,
    spectra_table: str = "spectra",
    external_meta: Optional[pd.DataFrame] = None,
    external_key_col: str = "inchikey14"
) -> pd.DataFrame:
    """
    Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14.

    Parameters
    ----------
    spectral_db_sqlite_path : SQLite file holding the spectra table.
    spectra_table : table providing the ``spec_id`` and ``inchikey`` columns.
    external_meta : optional metadata to left-join onto the result.
    external_key_col : column of ``external_meta`` matched against ``inchikey14``.

    Returns
    -------
    DataFrame with columns inchikey14, inchikey (first full key seen), n_spectra,
    followed by any columns joined from `external_meta`, sorted by n_spectra descending.
    """
    conn = sqlite3.connect(spectral_db_sqlite_path)
    try:
        conn.row_factory = sqlite3.Row
        # pull spec_id + inchikey from spectra
        df = pd.read_sql_query(f"SELECT spec_id, inchikey FROM {spectra_table}", conn)
    finally:
        # Release the connection even if the query fails (e.g. missing table).
        conn.close()

    if df.empty:
        base = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"])
        if external_meta is not None:
            return base.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col)
        return base

    # Compute inchikey14
    df["inchikey14"] = df["inchikey"].fillna("").map(inchikey14_from_full)

    # Aggregate
    agg = (
        df.dropna(subset=["inchikey14"])
          .groupby(["inchikey14"], as_index=False)
          .agg(n_spectra=("spec_id", "count"),
               inchikey=("inchikey", "first"))  # first full key seen
    )
    # Same column layout as the empty result above and as documented.
    agg = agg[["inchikey14", "inchikey", "n_spectra"]]

    # Optional join with external meta
    if external_meta is not None and not external_meta.empty:
        agg = agg.merge(external_meta, how="left", left_on="inchikey14", right_on=external_key_col)

    # Order by prevalence
    return agg.sort_values("n_spectra", ascending=False).reset_index(drop=True)

0 commit comments

Comments
 (0)