Skip to content

Commit ff6e2b6

Browse files
committed
fix by_ids methods and linting
1 parent: a59cc9a · commit: ff6e2b6

File tree

1 file changed

+85
-39
lines changed

1 file changed

+85
-39
lines changed

ms2query/database/spectral_database.py

Lines changed: 85 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,19 @@ def _normalize_metadata(md: Dict[str, Any], fields: Iterable[str]) -> Dict[str,
5757
class SpectralDatabase:
5858
sqlite_path: str
5959
table: str = "spectra"
60-
metadata_fields: List[str] = field(default_factory=lambda: [
61-
"precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name",
62-
"instrument_type", "adduct", "collision_energy"
63-
])
60+
metadata_fields: List[str] = field(
61+
default_factory=lambda: [
62+
"precursor_mz",
63+
"ionmode",
64+
"smiles",
65+
"inchikey",
66+
"inchi",
67+
"name",
68+
"instrument_type",
69+
"adduct",
70+
"collision_energy",
71+
]
72+
)
6473
spectrum_sum_normalization_for_embedding: bool = True
6574
_conn: sqlite3.Connection = field(init=False, repr=False)
6675
_ms2ds_model_path: Optional[str] = field(default=None, repr=False)
@@ -81,7 +90,8 @@ def add_spectra(self, spectra: List[Spectrum]) -> List[str]:
8190

8291
cur = self._conn.cursor()
8392
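# Bulk-load speed PRAGMAs (intended for single-user/batch ingest; synchronous=OFF trades durability for speed, so an OS crash or power loss mid-ingest can corrupt the database)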
84-
cur.executescript("""
93+
cur.executescript(
94+
"""
8595
PRAGMA journal_mode=WAL;
8696
PRAGMA synchronous=OFF;
8797
PRAGMA temp_store=MEMORY;
@@ -130,14 +140,14 @@ def ids(self) -> List[str]:
130140
rows = cur.execute(f"SELECT spec_id FROM {self.table}").fetchall()
131141
return [str(row["spec_id"]) for row in rows]
132142

133-
def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]:
134-
"""Retrieve full Spectrum objects for given specIDs (order preserved, missing IDs skipped)."""
143+
def get_spectra_by_ids(self, spec_ids: List[str]) -> List[Spectrum]:
144+
"""Retrieve full Spectrum objects for given spec_ids (order preserved, missing IDs skipped)."""
135145
rows = self._fetch_rows_by_ids(
136-
specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields))
146+
spec_ids, cols="spec_id, mz_blob, intensity_blob, n_peaks, " + ", ".join(self.metadata_fields))
137147
by_id = {row["spec_id"]: row for row in rows}
138148

139149
result: List[Spectrum] = []
140-
for sid in specIDs:
150+
for sid in spec_ids:
141151
row = by_id.get(sid)
142152
if row is None:
143153
continue
@@ -149,13 +159,19 @@ def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]:
149159
result.append(Spectrum(mz=mz, intensities=inten, metadata=md))
150160
return result
151161

152-
def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]:
153-
"""Retrieve (mz, intensity) arrays for given specIDs (order preserved, missing IDs skipped)."""
154-
rows = self._fetch_rows_by_ids(specIDs, cols="spec_id, mz_blob, intensity_blob, n_peaks")
162+
def get_fragments_by_ids(self, spec_ids: List[str]) -> List[Tuple[np.ndarray, np.ndarray]]:
163+
"""
164+
Retrieve (mz, intensity) arrays for given spec_ids.
165+
166+
Order is preserved with respect to `spec_ids`.
167+
Missing IDs are skipped.
168+
"""
169+
cols = "spec_id, mz_blob, intensity_blob, n_peaks"
170+
rows = self._fetch_rows_by_ids(spec_ids, cols=cols)
155171
by_id = {row["spec_id"]: row for row in rows}
156172

157173
out: List[Tuple[np.ndarray, np.ndarray]] = []
158-
for sid in specIDs:
174+
for sid in spec_ids:
159175
row = by_id.get(sid)
160176
if row is None:
161177
continue
@@ -165,16 +181,34 @@ def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np.
165181
out.append((mz, inten))
166182
return out
167183

168-
def get_metadata_by_ids(self, specIDs: List[str]) -> pd.DataFrame:
169-
"""Retrieve metadata for given specIDs (order preserved)."""
184+
def get_metadata_by_ids(self, spec_ids: List[str]) -> pd.DataFrame:
185+
"""
186+
Retrieve metadata for given spec_ids.
187+
188+
Returns a DataFrame with **one row per requested spec_id** in the same
189+
order as `spec_ids`. If a spec_id is not present in the database, a row
190+
with that spec_id and metadata columns set to None/NaN is returned.
191+
"""
170192
cols = ["spec_id"] + self.metadata_fields
171-
rows = self._fetch_rows_by_ids(specIDs, cols=", ".join(cols))
172-
df = pd.DataFrame(rows, columns=cols)
173-
if not df.empty:
174-
order = {sid: i for i, sid in enumerate(specIDs)}
175-
df["__order"] = df["spec_id"].map(order)
176-
df = df.sort_values("__order").drop(columns="__order").reset_index(drop=True)
177-
return df
193+
if not spec_ids:
194+
return pd.DataFrame(columns=cols)
195+
196+
rows = self._fetch_rows_by_ids(spec_ids, cols=", ".join(cols))
197+
by_id = {row["spec_id"]: row for row in rows}
198+
199+
records: List[Dict[str, Any]] = []
200+
for sid in spec_ids:
201+
row = by_id.get(sid)
202+
if row is None:
203+
rec = {"spec_id": sid}
204+
rec.update({k: None for k in self.metadata_fields})
205+
else:
206+
rec = {"spec_id": sid}
207+
for k in self.metadata_fields:
208+
rec[k] = row[k]
209+
records.append(rec)
210+
211+
return pd.DataFrame.from_records(records, columns=cols)
178212

179213
def sql_query(self, query: str) -> pd.DataFrame:
180214
"""Run a raw SQL SELECT and return a DataFrame."""
@@ -225,7 +259,6 @@ def compute_embeddings_to_sqlite(
225259
- Uses `matchms.Spectrum` objects reconstructed from the stored peaks & metadata.
226260
- Stores raw float32 vectors (no extra header) with their dimension `d`.
227261
"""
228-
# TODO: add batch_size to speed up?
229262
spectra_table = spectra_table or self.table
230263
self._ensure_schema() # spectra schema
231264
self.ensure_embeddings_schema(embeddings_table)
@@ -254,7 +287,9 @@ def compute_embeddings_to_sqlite(
254287
model = self.load_ms2deepscore_model(model_path)
255288

256289
inserted = 0
257-
buf: List[Tuple[str, bytes, bytes, int, float, str, Optional[int]]] = []
290+
buf: List[
291+
Tuple[str, bytes, bytes, int, float, str, Optional[int]]
292+
] = []
258293
done_since_commit = 0
259294

260295
def flush(batch) -> int:
@@ -265,24 +300,35 @@ def flush(batch) -> int:
265300
for sid, mz_blob, it_blob, n_peaks, prec_mz, ionmode, charge in batch:
266301
mz = _from_float32_bytes(mz_blob, int(n_peaks))
267302
it = _from_float32_bytes(it_blob, int(n_peaks))
268-
spectrum = Spectrum(mz=mz, intensities=it, metadata={
269-
"precursor_mz": float(prec_mz) if prec_mz is not None else None,
270-
"ionmode": ionmode,
271-
"charge": charge,
272-
"spec_id": sid,
273-
})
303+
spectrum = Spectrum(
304+
mz=mz,
305+
intensities=it,
306+
metadata={
307+
"precursor_mz": float(prec_mz) if prec_mz is not None else None,
308+
"ionmode": ionmode,
309+
"charge": charge,
310+
"spec_id": sid,
311+
},
312+
)
274313
specs.append(spectrum)
275314
sids.append(sid)
276315

277316
embeddings = compute_spectra_embeddings(
278-
model, specs,
279-
normalize_spectrum=self.spectrum_sum_normalization_for_embedding
280-
)
317+
model,
318+
specs,
319+
normalize_spectrum=self.spectrum_sum_normalization_for_embedding,
320+
)
281321
dim = int(embeddings.shape[1])
282-
q = f"INSERT OR REPLACE INTO {embeddings_table} (spec_id, d, vec) VALUES (?, ?, ?);"
322+
q = (
323+
f"INSERT OR REPLACE INTO {embeddings_table} "
324+
f"(spec_id, d, vec) VALUES (?, ?, ?);"
325+
)
283326
with self._conn:
284327
for sid, embedding in zip(sids, embeddings):
285-
self._conn.execute(q, (sid, dim, sqlite3.Binary(_as_float32_bytes(embedding))))
328+
self._conn.execute(
329+
q,
330+
(sid, dim, sqlite3.Binary(_as_float32_bytes(embedding))),
331+
)
286332
return len(batch)
287333

288334
while True:
@@ -360,13 +406,13 @@ def connection(self) -> sqlite3.Connection:
360406
return self._conn
361407
# ---------- internal ----------
362408

363-
def _fetch_rows_by_ids(self, specIDs: List[str], cols: str) -> List[sqlite3.Row]:
364-
if not specIDs:
409+
def _fetch_rows_by_ids(self, spec_ids: List[str], cols: str) -> List[sqlite3.Row]:
410+
if not spec_ids:
365411
return []
366-
placeholders = ",".join("?" for _ in specIDs)
412+
placeholders = ",".join("?" for _ in spec_ids)
367413
sql = f"SELECT {cols} FROM {self.table} WHERE spec_id IN ({placeholders})"
368414
cur = self._conn.cursor()
369-
return cur.execute(sql, specIDs).fetchall()
415+
return cur.execute(sql, spec_ids).fetchall()
370416

371417
def _ensure_schema(self):
372418
cur = self._conn.cursor()

0 commit comments

Comments
 (0)