Skip to content

Commit 63a429b

Browse files
committed
replace dataframe addition method + add tests
1 parent cbd51f6 commit 63a429b

File tree

2 files changed

+224
-44
lines changed

2 files changed

+224
-44
lines changed

ms2query/database/compound_database.py

Lines changed: 111 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -321,68 +321,137 @@ def upsert_many(self, rows: Iterable[Dict[str, Any]]) -> List[str]:
321321
cur.executemany(UPSERT_SQL.format(table=self.table), payloads)
322322
return comp_ids
323323

324-
def overwrite_metadata_from_dataframe(
    self,
    df: pd.DataFrame,
    *,
    column_mapper: Optional[Dict[str, str]] = None,
    chunksize: int = 50_000,
    staging_table: str = "_staging_compounds",
) -> dict:
    """
    Fast initialize/replace of the compounds table from a wide DataFrame.

    The target table is dropped and recreated, so this is a true replace:
    rows not present in `df` are gone afterwards.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing compound metadata.
    column_mapper : Optional[Dict[str, str]], optional
        Mapping of DataFrame columns to expected compound fields, by default None.
    chunksize : int, optional
        Number of rows per chunk when writing to the database, by default 50_000.
    staging_table : str, optional
        Name of the temporary staging table used for the bulk load,
        by default "_staging_compounds". It is dropped when the load finishes.

    Returns
    -------
    dict
        Stats with keys 'rows' (incoming), 'valid' (after validation and
        de-duplication), 'written' (rows copied into the main table) and
        'skipped' (rows with a missing/malformed 14-char key).

    Pass `column_mapper` to map your DataFrame columns to the expected names.
    Supported mapping keys (values are your df column names):
      - 'comp_id' (14-char key; if present, used as-is)
      - 'inchikey' (full key; used to derive comp_id if no comp_id provided)
      - 'smiles'
      - 'inchi'
      - 'classyfire_class'
      - 'classyfire_superclass'

    If you don't pass a mapper, this will auto-detect common aliases for the 14-char key:
    'nchikey', 'inchikey14', 'inchikey_14', 'ik14', 'comp_id'
    """
    if df is None or df.empty:
        return {"rows": 0, "valid": 0, "written": 0, "skipped": 0}

    # ----- resolve columns (mapper-aware with sensible defaults) -----
    default_map = {
        "comp_id": None,  # optional (14-char key)
        "inchikey": "inchikey",
        "smiles": "smiles",
        "inchi": "inchi",
        "classyfire_class": "classyfire_class",
        "classyfire_superclass": "classyfire_superclass",
    }
    cmap = {k: (column_mapper.get(k) if column_mapper and k in column_mapper else v)
            for k, v in default_map.items()}

    # Auto-detect a 14-char key if no mapping was provided for comp_id
    def _first_present(cols: list[str]) -> Optional[str]:
        for c in cols:
            if c in df.columns:
                return c
        return None

    if cmap["comp_id"] is None:
        cmap["comp_id"] = _first_present(["nchikey", "inchikey14", "inchikey_14", "ik14", "comp_id"])

    # We need either a 14-char key or a full inchikey (possibly via mapping)
    has_comp14 = cmap["comp_id"] is not None and cmap["comp_id"] in df.columns
    has_fullik = cmap["inchikey"] is not None and cmap["inchikey"] in df.columns
    if not has_comp14 and not has_fullik:
        raise ValueError(
            "DataFrame must contain either a 14-char key "
            "(map it via column_mapper['comp_id']) or a full 'inchikey' "
            "(map it via column_mapper['inchikey'])."
        )

    # ----- build minimal working frame -----
    work = pd.DataFrame()

    # comp_id (14-char): take as-is when supplied, otherwise derive it
    if has_comp14:
        work["comp_id"] = df[cmap["comp_id"]].astype(str).str.strip()
    else:
        # derive from full inchikey
        full = df[cmap["inchikey"]].astype(str).str.strip()
        work["comp_id"] = full.map(inchikey14_from_full)

    # inchikey (full) if present; otherwise stored as NULL
    work["inchikey"] = df[cmap["inchikey"]].astype(str).str.strip() if has_fullik else None

    # optional metadata (mapper-aware); missing columns become NULL
    for k in ("smiles", "inchi", "classyfire_class", "classyfire_superclass"):
        src = cmap[k]
        work[k] = df[src] if (src is not None and src in df.columns) else None

    # ----- validate / deduplicate -----
    comp = work["comp_id"].astype(str).str.strip()
    valid_mask = comp.str.len().eq(14) & comp.ne("")
    skipped = int((~valid_mask).sum())

    work = (work.loc[valid_mask, ["comp_id", "inchikey", "smiles", "inchi",
                                  "classyfire_class", "classyfire_superclass"]]
            .drop_duplicates(subset=["comp_id"], keep="last")
            .reset_index(drop=True))

    if work.empty:
        return {"rows": int(len(df)), "valid": 0, "written": 0, "skipped": int(skipped)}

    # ----- bulk load: staging -> recreate main table (fast) -----
    cur = self._conn.cursor()
    # Speed PRAGMAs during load. Remember the previous synchronous level so
    # it can be restored afterwards: leaving synchronous=OFF would silently
    # disable durability for every later write on this connection.
    prev_sync = cur.execute("PRAGMA synchronous").fetchone()[0]
    cur.execute("PRAGMA synchronous=OFF")
    cur.execute("PRAGMA temp_store=MEMORY")
    cur.execute("PRAGMA cache_size=-200000")

    try:
        # 1) write to staging with big chunks & multi-row inserts
        work.to_sql(staging_table, self._conn, if_exists="replace", index=False,
                    chunksize=chunksize, method="multi")

        # 2) atomically recreate the target table with schema + copy from staging
        # NOTE(review): sqlite3's executescript() implicitly commits any
        # pending transaction before running (on Python < 3.12), which may
        # weaken the atomicity provided by self._tx() — confirm.
        with self._tx() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table}")
            cur.executescript(SCHEMA_SQL.format(table=self.table))
            cur.execute(f"""
                INSERT INTO {self.table} (
                    comp_id, smiles, inchi, inchikey,
                    classyfire_class, classyfire_superclass
                )
                SELECT comp_id, smiles, inchi, inchikey,
                       classyfire_class, classyfire_superclass
                FROM {staging_table}
            """)
            # DB-API allows rowcount == -1 ("not determined"); `rowcount or
            # len(work)` would then report -1, so only trust positive counts.
            written = cur.rowcount if cur.rowcount and cur.rowcount > 0 else len(work)
            cur.execute(f"DROP TABLE IF EXISTS {staging_table}")
    finally:
        # Restore durability even if the load failed part-way.
        self._conn.execute(f"PRAGMA synchronous={prev_sync}")

    # Make sure indexes & settings table exist (idempotent)
    self._ensure_schema_and_settings()

    return {"rows": int(len(df)), "valid": int(len(work)), "written": int(written), "skipped": int(skipped)}
386455

387456
def compute_fingerprints(
388457
self,
@@ -447,7 +516,8 @@ def get_fingerprints(self, comp_id_list: List[str]):
447516
for cid in comp_id_list:
448517
r = next((row for row in rows if row["comp_id"] == cid), None)
449518
if r is None:
450-
out.append(None); continue
519+
out.append(None)
520+
continue
451521
dense_blob = r["fingerprint_dense"] or b""
452522
bits_blob = r["fingerprint_bits"] or b""
453523
counts_blob = r["fingerprint_counts"] or b""

tests/test_compound_database.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import sqlite3
22
from pathlib import Path
33
import numpy as np
4+
import pandas as pd
45
import pytest
5-
from ms2query.data_processing import compute_morgan_fingerprints
6+
from ms2query.data_processing import compute_morgan_fingerprints, inchikey14_from_full
67
from ms2query.database.compound_database import (
78
CompoundDatabase,
89
)
@@ -64,7 +65,6 @@ def test_compute_fingerprints_contract():
6465
bits, counts = fp
6566
assert isinstance(bits, np.ndarray) and bits.dtype == np.uint32
6667
assert isinstance(counts, np.ndarray)
67-
# counts are usually integer-like (could be float if you later scale)
6868
assert counts.ndim == 1
6969

7070
# -------------------------
@@ -187,4 +187,114 @@ def test_compute_fingerprints_method(count, sparse):
187187
else:
188188
assert np.allclose(fps_directly[0], fps_after[0])
189189
cdb.close()
190-
190+
191+
192+
def test_overwrite_metadata_from_dataframe_basic_and_mapping(tmp_path):
    """Mapped aliases, invalid-key skipping, duplicate keep-last and NULL columns."""
    db_path = tmp_path / "compounds.sqlite"
    cdb = CompoundDatabase(str(db_path))

    # Wide DF with aliases + extras; includes:
    # - valid 14-char keys via 'nchikey'
    # - one invalid key (too short) -> skipped
    # - one duplicate comp_id -> keep last
    df = pd.DataFrame({
        "nchikey": ["AAAQFGUYHFJNHI", "AABFWJDLCCDJJN", "SHORTKEY", "AABFWJDLCCDJJN"],
        "smiles": ["S1", "S2", "S_bad", "S2_override"],
        "cf_class": ["C1", "C2", "C_bad", "C2_override"],
        "cf_superclass": ["SC1", "SC2", "SC_bad", "SC2_override"],
        "mass": [423.146, 324.126, 0.0, 999.0],  # extra column to be ignored
    })

    stats = cdb.overwrite_metadata_from_dataframe(
        df,
        column_mapper={  # map aliases -> expected names
            "comp_id": "nchikey",
            "smiles": "smiles",
            "classyfire_class": "cf_class",
            "classyfire_superclass": "cf_superclass",
        }
    )

    # Rows: 4 incoming, 1 invalid (SHORTKEY) -> skipped=1
    # Valid comp_ids: AAAQFGUYHFJNHI, AABFWJDLCCDJJN (duplicate -> keep last) => written=2
    assert stats["rows"] == 4
    assert stats["skipped"] == 1
    assert stats["valid"] == 2
    assert stats["written"] == 2

    # Check DB content
    df_db = pd.read_sql_query("SELECT comp_id, smiles, classyfire_class, classyfire_superclass, inchikey, inchi FROM compounds", cdb._conn)
    assert set(df_db["comp_id"]) == {"AAAQFGUYHFJNHI", "AABFWJDLCCDJJN"}

    # Row without full inchikey provided -> stored as NULL. Inchi not provided -> NULL.
    # Use pd.isna instead of `in (None, np.nan, "")`: tuple membership relies on
    # NaN object identity, which a pandas-produced NaN is not guaranteed to have.
    ik_val = df_db.loc[df_db["comp_id"] == "AAAQFGUYHFJNHI", "inchikey"].iloc[0]
    inchi_val = df_db.loc[df_db["comp_id"] == "AAAQFGUYHFJNHI", "inchi"].iloc[0]
    assert pd.isna(ik_val) or ik_val == ""
    assert pd.isna(inchi_val) or inchi_val == ""

    # "keep last" behavior for duplicate comp_id
    r = df_db.set_index("comp_id").loc["AABFWJDLCCDJJN"]
    assert r["smiles"] == "S2_override"
    assert r["classyfire_class"] == "C2_override"
    assert r["classyfire_superclass"] == "SC2_override"

    # Settings table is intact and readable
    settings = cdb.get_fingerprint_settings()
    assert {"nbits", "radius", "sparse", "count", "dtype"} <= set(settings.keys())

    cdb.close()
244+
245+
246+
def test_overwrite_metadata_from_dataframe_derive_comp_id_and_true_replace(tmp_path):
    """comp_id derivation from full InChIKeys, then a second load that fully replaces."""
    database = CompoundDatabase(str(tmp_path / "compounds.sqlite"))

    # Load #1: no 14-char key column at all — comp_id must come from the
    # full InChIKey, which lives under a custom column name.
    first_batch = pd.DataFrame({
        "IK_FULL": [
            "BQJCRHHNABKAKU-KBQPJGBKSA-N",
            "BSYNRYMUTXBXSQ-UHFFFAOYSA-N",
        ],
        "smiles": ["CCO", "O=C=O"],
        "inchi": ["InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3", "InChI=1S/CO2/c2-1-3"],
        "cf_class": ["Alcohols", "Carbon oxides"],
        "cf_superclass": ["Organooxygen compounds", "Inorganic compounds"],
    })

    first_stats = database.overwrite_metadata_from_dataframe(
        first_batch,
        column_mapper={
            "inchikey": "IK_FULL",  # derive comp_id from full IK
            "smiles": "smiles",
            "inchi": "inchi",
            "classyfire_class": "cf_class",
            "classyfire_superclass": "cf_superclass",
        }
    )
    assert first_stats["written"] == 2

    loaded = pd.read_sql_query("SELECT comp_id, inchikey, smiles FROM compounds ORDER BY comp_id", database._conn)
    # every stored comp_id equals inchikey14_from_full(inchikey)
    for stored_id, full_key in zip(loaded["comp_id"], loaded["inchikey"]):
        assert stored_id == inchikey14_from_full(full_key)

    # Load #2: a disjoint compound set — the first batch must vanish entirely.
    replacement = pd.DataFrame({
        "IK_FULL": ["AAOVKJBEBIDNHE-UHFFFAOYSA-N"],
        "smiles": ["CC(=O)O"],
        "cf_class": ["Carboxylic acids"],
        "cf_superclass": ["Organooxygen compounds"],
    })
    second_stats = database.overwrite_metadata_from_dataframe(
        replacement,
        column_mapper={
            "inchikey": "IK_FULL",
            "smiles": "smiles",
            "classyfire_class": "cf_class",
            "classyfire_superclass": "cf_superclass",
        }
    )
    assert second_stats["written"] == 1

    remaining = pd.read_sql_query("SELECT comp_id, inchikey, smiles FROM compounds", database._conn)
    assert len(remaining) == 1
    assert remaining.iloc[0]["comp_id"] == inchikey14_from_full(remaining.iloc[0]["inchikey"])
    assert set(remaining["smiles"]) == {"CC(=O)O"}  # previous rows gone (true replace)

    database.close()

0 commit comments

Comments
 (0)