|
5 | 5 | import numpy as np |
6 | 6 | import pandas as pd |
7 | 7 | from ms2query.data_processing import compute_morgan_fingerprints, inchikey14_from_full |
8 | | - |
9 | | - |
10 | | -# ========================= |
11 | | -# Utilities & placeholders |
12 | | -# ========================= |
13 | | - |
def encode_sparse_fp(bits: Optional[np.ndarray], counts: Optional[np.ndarray]) -> tuple[bytes, bytes]:
    """Serialize a sparse fingerprint as two binary blobs.

    Parameters
    ----------
    bits : array-like of bit indices, or None
    counts : array-like of per-bit counts, or None

    Returns
    -------
    tuple[bytes, bytes]
        (bits_blob, counts_blob). ``None`` inputs yield empty blobs.

    Notes
    -----
    Bits are always stored as uint32 and counts always as int32 so that
    ``decode_sparse_fp`` (which reads uint32 / int32) round-trips exactly.
    Previously uint8/uint16 count arrays were written verbatim, which the
    int32 reader in ``decode_sparse_fp`` would misinterpret (wrong element
    width) — that is fixed here by coercing every counts dtype to int32.
    """
    if bits is None:
        bits_blob = b""
    else:
        bit_arr = np.asarray(bits)
        if bit_arr.dtype != np.uint32:
            bit_arr = bit_arr.astype(np.uint32, copy=False)
        bits_blob = bit_arr.tobytes(order="C")

    if counts is None:
        counts_blob = b""
    else:
        count_arr = np.asarray(counts)
        # Always coerce to int32: decode_sparse_fp reads counts as int32,
        # so any other stored width would corrupt the round-trip.
        if count_arr.dtype != np.int32:
            count_arr = count_arr.astype(np.int32, copy=False)
        counts_blob = count_arr.tobytes(order="C")

    return bits_blob, counts_blob
38 | | - |
def decode_sparse_fp(bits_blob: bytes, counts_blob: bytes) -> tuple[np.ndarray, np.ndarray]:
    """Inverse of encode_sparse_fp.

    Parameters
    ----------
    bits_blob : BLOB bytes of uint32 bit indices
    counts_blob : BLOB bytes of int32 counts

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (bits_uint32, counts_int32); empty blobs map to empty arrays.
    """
    if bits_blob:
        bit_indices = np.frombuffer(bits_blob, dtype=np.uint32).copy()
    else:
        bit_indices = np.zeros(0, dtype=np.uint32)
    # Signedness is not stored in the blob; int32 is the assumed default.
    if counts_blob:
        bit_counts = np.frombuffer(counts_blob, dtype=np.int32).copy()
    else:
        bit_counts = np.zeros(0, dtype=np.int32)
    return bit_indices, bit_counts
53 | | - |
def encode_dense_fp(vec: Optional[np.ndarray]) -> bytes:
    """Encode a dense vector as float32 bytes. None -> empty blob."""
    if vec is None:
        return b""
    # asarray with an explicit dtype avoids a copy when vec is already float32.
    flat = np.asarray(vec, dtype=np.float32).ravel()
    return flat.tobytes(order="C")
62 | | - |
def decode_dense_fp(blob: bytes, dtype: str = "float32") -> np.ndarray:
    """Decode a dense vector from blob with the given dtype (default float32)."""
    target = np.dtype(dtype)
    if not blob:
        return np.zeros(0, dtype=target)
    return np.frombuffer(blob, dtype=target).copy()
68 | | - |
69 | | -#def decode_fp_blob(blob: bytes) -> np.ndarray: |
70 | | -# """Decode fingerprint BLOB back to uint8 array. |
71 | | -# Unknown length -> infer from blob size.""" |
72 | | -# if not blob: |
73 | | -# return np.zeros(0, dtype=np.uint8) |
74 | | -# return np.frombuffer(blob, dtype=np.uint8).copy() |
| 8 | +from ms2query.database.database_utils import decode_dense_fp, decode_sparse_fp, encode_dense_fp, encode_sparse_fp |
75 | 9 |
|
76 | 10 |
|
77 | 11 | # ================================================== |
@@ -777,229 +711,3 @@ def _update_rows_sparse_with_counts(comp_ids: List[str], pairs: List[Tuple[np.nd |
777 | 711 | """)["n"].iloc[0] |
778 | 712 |
|
779 | 713 | return stats |
780 | | - |
781 | | - |
782 | | -# ================================================== |
783 | | -# Mapping: spectrum <-> compound (spec_to_comp) |
784 | | -# ================================================== |
785 | | - |
@dataclass
class SpecToCompoundMap:
    """Persists (spec_id -> comp_id) mappings in SQLite.

    Use the SAME DB file as SpectralDatabase for simplicity; only this
    class's own table plus one index are created here.
    """
    sqlite_path: str
    table: str = "spec_to_comp"
    compound_table: str = "compounds"  # NOTE(review): not referenced by any method in this class
    _conn: sqlite3.Connection = field(init=False, repr=False)

    def __post_init__(self):
        # Ensure the parent directory exists before SQLite opens the file.
        Path(self.sqlite_path).parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(self.sqlite_path)
        self._conn.row_factory = sqlite3.Row
        self._ensure_schema()

    def close(self):
        """Close the underlying connection, ignoring any close-time error."""
        try:
            self._conn.close()
        except Exception:
            pass

    def _ensure_schema(self):
        # No strict FK enforcement (SpectralDatabase may have been created
        # without the FK pragma); instead index both sides for fast lookup.
        ddl = f"""
        CREATE TABLE IF NOT EXISTS {self.table}(
            spec_id INTEGER NOT NULL,
            comp_id TEXT NOT NULL,
            PRIMARY KEY (spec_id),
            CHECK (length(comp_id) = 14)
        );
        CREATE INDEX IF NOT EXISTS idx_spec_to_comp_comp ON {self.table}(comp_id);
        """
        self._conn.cursor().executescript(ddl)
        self._conn.commit()

    # ---------- API ----------

    def link(self, spec_id: int, comp_id: str):
        """Insert or replace a single mapping."""
        if not comp_id or len(comp_id) != 14:
            raise ValueError("comp_id must be inchikey14 (14 characters).")
        upsert = (
            f"INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) "
            f"ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id"
        )
        self._conn.execute(upsert, (spec_id, comp_id))
        self._conn.commit()

    def link_many(self, pairs: Iterable[Tuple[int, str]]):
        """Bulk link (spec_id, comp_id) inside a single transaction."""
        upsert = (
            f"INSERT INTO {self.table} (spec_id, comp_id) VALUES (?, ?) "
            f"ON CONFLICT(spec_id) DO UPDATE SET comp_id=excluded.comp_id"
        )
        cur = self._conn.cursor()
        cur.execute("BEGIN")
        try:
            cur.executemany(upsert, list(pairs))
        except Exception:
            cur.execute("ROLLBACK")
            raise
        else:
            cur.execute("COMMIT")

    def get_comp_id_for_specs(self, spec_ids: List[int]) -> pd.DataFrame:
        """Return a DataFrame with columns [spec_id, comp_id] for the given spec_ids."""
        if not spec_ids:
            return pd.DataFrame(columns=["spec_id", "comp_id"])
        qmarks = ",".join("?" * len(spec_ids))
        query = f"SELECT spec_id, comp_id FROM {self.table} WHERE spec_id IN ({qmarks})"
        found = self._conn.execute(query, spec_ids).fetchall()
        return pd.DataFrame(found, columns=["spec_id", "comp_id"])

    def get_specs_for_comp(self, comp_id: str) -> List[int]:
        """Return the list of spec_ids mapped to a given comp_id."""
        query = f"SELECT spec_id FROM {self.table} WHERE comp_id = ?"
        return [row[0] for row in self._conn.execute(query, (comp_id,)).fetchall()]

    def get_all_mappings(self) -> pd.DataFrame:
        """Return every spec_id <-> comp_id mapping as a DataFrame."""
        found = self._conn.execute(f"SELECT spec_id, comp_id FROM {self.table}").fetchall()
        return pd.DataFrame(found, columns=["spec_id", "comp_id"])
869 | | - |
870 | | - |
871 | | -# ================================================== |
872 | | -# Integrations with SpectralDatabase |
873 | | -# ================================================== |
874 | | - |
def map_from_spectraldb_metadata(
    spectral_db_sqlite_path: str,
    mapping_sqlite_path: Optional[str] = None,
    compounds_sqlite_path: Optional[str] = None,
    spectra_table: str = "spectra",
    compound_table: str = "compounds",
    mapping_table: str = "spec_to_comp",
    *,
    create_missing_compounds: bool = True
) -> Tuple[int, int]:
    """
    Read spectra metadata (expects 'inchikey' in metadata), create comp_id (inchikey14),
    populate spec_to_comp, and optionally upsert minimal compounds.

    Returns: (n_mapped, n_new_compounds)
    """
    # Plain sqlite access instead of importing the class (avoids circular imports).
    src = sqlite3.connect(spectral_db_sqlite_path)
    src.row_factory = sqlite3.Row

    mapper = SpecToCompoundMap(mapping_sqlite_path or spectral_db_sqlite_path,
                               table=mapping_table)
    compdb = CompoundDatabase(compounds_sqlite_path or spectral_db_sqlite_path,
                              table=compound_table)

    # Select only those of the wanted columns that actually exist in the table.
    existing = {row[1] for row in src.execute(f"PRAGMA table_info({spectra_table})").fetchall()}
    wanted = ["spec_id", "inchikey", "smiles", "inchi", "classyfire_class", "classyfire_superclass"]
    selected = ", ".join(col for col in wanted if col in existing)

    records = src.execute(f"SELECT {selected} FROM {spectra_table}").fetchall()

    links: List[Tuple[int, str]] = []
    candidate_compounds: List[Dict[str, Any]] = []

    for record in records:
        meta = dict(record)
        spec_id = int(meta["spec_id"])
        full_key = meta.get("inchikey")
        if not full_key:
            continue
        comp_id = inchikey14_from_full(full_key)
        if not comp_id:
            continue
        links.append((spec_id, comp_id))

        if create_missing_compounds:
            candidate_compounds.append({
                "smiles": meta.get("smiles"),
                "inchi": meta.get("inchi"),
                "inchikey": full_key,
                "classyfire_class": meta.get("classyfire_class"),
                "classyfire_superclass": meta.get("classyfire_superclass"),
                "fingerprint": None,  # backfill later
            })

    # Bulk linking
    if links:
        mapper.link_many(links)

    # Upsert compounds (deduplicated by comp_id to avoid redundant upserts)
    n_new_compounds = 0
    if create_missing_compounds and candidate_compounds:
        seen: set[str] = set()
        unique_rows: List[Dict[str, Any]] = []
        for row in candidate_compounds:
            cid = inchikey14_from_full(row["inchikey"])
            if cid and cid not in seen:
                seen.add(cid)
                unique_rows.append(row)
        count_sql = f"SELECT COUNT(*) AS n FROM {compound_table}"
        before = compdb.sql_query(count_sql)["n"].iloc[0]
        compdb.upsert_many(unique_rows)
        after = compdb.sql_query(count_sql)["n"].iloc[0]
        n_new_compounds = int(after - before)

    # Close connections
    mapper.close()
    compdb.close()
    src.close()

    return len(links), n_new_compounds
961 | | - |
962 | | - |
def get_unique_compounds_from_spectraldb(
    spectral_db_sqlite_path: str,
    spectra_table: str = "spectra",
    external_meta: Optional[pd.DataFrame] = None,
    external_key_col: str = "inchikey14"
) -> pd.DataFrame:
    """
    Return a DataFrame of unique compounds present in the spectral DB, inferred via inchikey → inchikey14.
    Columns: inchikey14, inchikey (full), n_spectra. If `external_meta` is provided,
    it will be left-joined on `external_key_col` (default 'inchikey14').
    """
    conn = sqlite3.connect(spectral_db_sqlite_path)
    conn.row_factory = sqlite3.Row
    spectra = pd.read_sql_query(f"SELECT spec_id, inchikey FROM {spectra_table}", conn)
    conn.close()

    # Empty DB: still honor the optional left join so the schema is consistent.
    if spectra.empty:
        empty = pd.DataFrame(columns=["inchikey14", "inchikey", "n_spectra"])
        if external_meta is not None:
            empty = empty.merge(external_meta, how="left",
                                left_on="inchikey14", right_on=external_key_col)
        return empty

    # Derive the 14-character key from each full inchikey.
    spectra["inchikey14"] = spectra["inchikey"].fillna("").map(inchikey14_from_full)

    # One row per inchikey14, keeping the first full key seen and a spectrum count.
    grouped = (
        spectra.dropna(subset=["inchikey14"])
        .groupby(["inchikey14"], as_index=False)
        .agg(n_spectra=("spec_id", "count"),
             inchikey=("inchikey", "first"))
    )

    # Optional join with external meta
    if external_meta is not None and not external_meta.empty:
        grouped = grouped.merge(external_meta, how="left",
                                left_on="inchikey14", right_on=external_key_col)

    # Order by prevalence
    return grouped.sort_values("n_spectra", ascending=False).reset_index(drop=True)
0 commit comments