@@ -57,10 +57,19 @@ def _normalize_metadata(md: Dict[str, Any], fields: Iterable[str]) -> Dict[str,
5757class SpectralDatabase :
5858 sqlite_path : str
5959 table : str = "spectra"
60- metadata_fields : List [str ] = field (default_factory = lambda : [
61- "precursor_mz" , "ionmode" , "smiles" , "inchikey" , "inchi" , "name" ,
62- "instrument_type" , "adduct" , "collision_energy"
63- ])
60+ metadata_fields : List [str ] = field (
61+ default_factory = lambda : [
62+ "precursor_mz" ,
63+ "ionmode" ,
64+ "smiles" ,
65+ "inchikey" ,
66+ "inchi" ,
67+ "name" ,
68+ "instrument_type" ,
69+ "adduct" ,
70+ "collision_energy" ,
71+ ]
72+ )
6473 spectrum_sum_normalization_for_embedding : bool = True
6574 _conn : sqlite3 .Connection = field (init = False , repr = False )
6675 _ms2ds_model_path : Optional [str ] = field (default = None , repr = False )
@@ -81,7 +90,8 @@ def add_spectra(self, spectra: List[Spectrum]) -> List[str]:
8190
8291 cur = self ._conn .cursor ()
8392 # Bulk-load speed PRAGMAs (safe for single-user/batch ingest)
84- cur .executescript ("""
93+ cur .executescript (
94+ """
8595 PRAGMA journal_mode=WAL;
8696 PRAGMA synchronous=OFF;
8797 PRAGMA temp_store=MEMORY;
@@ -130,14 +140,14 @@ def ids(self) -> List[str]:
130140 rows = cur .execute (f"SELECT spec_id FROM { self .table } " ).fetchall ()
131141 return [str (row ["spec_id" ]) for row in rows ]
132142
133- def get_spectra_by_ids (self , specIDs : List [str ]) -> List [Spectrum ]:
134- """Retrieve full Spectrum objects for given specIDs (order preserved, missing IDs skipped)."""
143+ def get_spectra_by_ids (self , spec_ids : List [str ]) -> List [Spectrum ]:
144+ """Retrieve full Spectrum objects for given spec_ids (order preserved, missing IDs skipped)."""
135145 rows = self ._fetch_rows_by_ids (
136- specIDs , cols = "spec_id, mz_blob, intensity_blob, n_peaks, " + ", " .join (self .metadata_fields ))
146+ spec_ids , cols = "spec_id, mz_blob, intensity_blob, n_peaks, " + ", " .join (self .metadata_fields ))
137147 by_id = {row ["spec_id" ]: row for row in rows }
138148
139149 result : List [Spectrum ] = []
140- for sid in specIDs :
150+ for sid in spec_ids :
141151 row = by_id .get (sid )
142152 if row is None :
143153 continue
@@ -149,13 +159,19 @@ def get_spectra_by_ids(self, specIDs: List[str]) -> List[Spectrum]:
149159 result .append (Spectrum (mz = mz , intensities = inten , metadata = md ))
150160 return result
151161
152- def get_fragments_by_ids (self , specIDs : List [str ]) -> List [Tuple [np .ndarray , np .ndarray ]]:
153- """Retrieve (mz, intensity) arrays for given specIDs (order preserved, missing IDs skipped)."""
154- rows = self ._fetch_rows_by_ids (specIDs , cols = "spec_id, mz_blob, intensity_blob, n_peaks" )
162+ def get_fragments_by_ids (self , spec_ids : List [str ]) -> List [Tuple [np .ndarray , np .ndarray ]]:
163+ """
164+ Retrieve (mz, intensity) arrays for given spec_ids.
165+
166+ Order is preserved with respect to `spec_ids`.
167+ Missing IDs are skipped.
168+ """
169+ cols = "spec_id, mz_blob, intensity_blob, n_peaks"
170+ rows = self ._fetch_rows_by_ids (spec_ids , cols = cols )
155171 by_id = {row ["spec_id" ]: row for row in rows }
156172
157173 out : List [Tuple [np .ndarray , np .ndarray ]] = []
158- for sid in specIDs :
174+ for sid in spec_ids :
159175 row = by_id .get (sid )
160176 if row is None :
161177 continue
@@ -165,16 +181,34 @@ def get_fragments_by_ids(self, specIDs: List[str]) -> List[Tuple[np.ndarray, np.
165181 out .append ((mz , inten ))
166182 return out
167183
168- def get_metadata_by_ids (self , specIDs : List [str ]) -> pd .DataFrame :
169- """Retrieve metadata for given specIDs (order preserved)."""
184+ def get_metadata_by_ids (self , spec_ids : List [str ]) -> pd .DataFrame :
185+ """
186+ Retrieve metadata for given spec_ids.
187+
188+ Returns a DataFrame with **one row per requested spec_id** in the same
189+ order as `spec_ids`. If a spec_id is not present in the database, a row
190+ with that spec_id and metadata columns set to None/NaN is returned.
191+ """
170192 cols = ["spec_id" ] + self .metadata_fields
171- rows = self ._fetch_rows_by_ids (specIDs , cols = ", " .join (cols ))
172- df = pd .DataFrame (rows , columns = cols )
173- if not df .empty :
174- order = {sid : i for i , sid in enumerate (specIDs )}
175- df ["__order" ] = df ["spec_id" ].map (order )
176- df = df .sort_values ("__order" ).drop (columns = "__order" ).reset_index (drop = True )
177- return df
193+ if not spec_ids :
194+ return pd .DataFrame (columns = cols )
195+
196+ rows = self ._fetch_rows_by_ids (spec_ids , cols = ", " .join (cols ))
197+ by_id = {row ["spec_id" ]: row for row in rows }
198+
199+ records : List [Dict [str , Any ]] = []
200+ for sid in spec_ids :
201+ row = by_id .get (sid )
202+ if row is None :
203+ rec = {"spec_id" : sid }
204+ rec .update ({k : None for k in self .metadata_fields })
205+ else :
206+ rec = {"spec_id" : sid }
207+ for k in self .metadata_fields :
208+ rec [k ] = row [k ]
209+ records .append (rec )
210+
211+ return pd .DataFrame .from_records (records , columns = cols )
178212
179213 def sql_query (self , query : str ) -> pd .DataFrame :
180214 """Run a raw SQL SELECT and return a DataFrame."""
@@ -225,7 +259,6 @@ def compute_embeddings_to_sqlite(
225259 - Uses `matchms.Spectrum` objects reconstructed from the stored peaks & metadata.
226260 - Stores raw float32 vectors (no extra header) with their dimension `d`.
227261 """
228- # TODO: add batch_size to speed up?
229262 spectra_table = spectra_table or self .table
230263 self ._ensure_schema () # spectra schema
231264 self .ensure_embeddings_schema (embeddings_table )
@@ -254,7 +287,9 @@ def compute_embeddings_to_sqlite(
254287 model = self .load_ms2deepscore_model (model_path )
255288
256289 inserted = 0
257- buf : List [Tuple [str , bytes , bytes , int , float , str , Optional [int ]]] = []
290+ buf : List [
291+ Tuple [str , bytes , bytes , int , float , str , Optional [int ]]
292+ ] = []
258293 done_since_commit = 0
259294
260295 def flush (batch ) -> int :
@@ -265,24 +300,35 @@ def flush(batch) -> int:
265300 for sid , mz_blob , it_blob , n_peaks , prec_mz , ionmode , charge in batch :
266301 mz = _from_float32_bytes (mz_blob , int (n_peaks ))
267302 it = _from_float32_bytes (it_blob , int (n_peaks ))
268- spectrum = Spectrum (mz = mz , intensities = it , metadata = {
269- "precursor_mz" : float (prec_mz ) if prec_mz is not None else None ,
270- "ionmode" : ionmode ,
271- "charge" : charge ,
272- "spec_id" : sid ,
273- })
303+ spectrum = Spectrum (
304+ mz = mz ,
305+ intensities = it ,
306+ metadata = {
307+ "precursor_mz" : float (prec_mz ) if prec_mz is not None else None ,
308+ "ionmode" : ionmode ,
309+ "charge" : charge ,
310+ "spec_id" : sid ,
311+ },
312+ )
274313 specs .append (spectrum )
275314 sids .append (sid )
276315
277316 embeddings = compute_spectra_embeddings (
278- model , specs ,
279- normalize_spectrum = self .spectrum_sum_normalization_for_embedding
280- )
317+ model ,
318+ specs ,
319+ normalize_spectrum = self .spectrum_sum_normalization_for_embedding ,
320+ )
281321 dim = int (embeddings .shape [1 ])
282- q = f"INSERT OR REPLACE INTO { embeddings_table } (spec_id, d, vec) VALUES (?, ?, ?);"
322+ q = (
323+ f"INSERT OR REPLACE INTO { embeddings_table } "
324+ f"(spec_id, d, vec) VALUES (?, ?, ?);"
325+ )
283326 with self ._conn :
284327 for sid , embedding in zip (sids , embeddings ):
285- self ._conn .execute (q , (sid , dim , sqlite3 .Binary (_as_float32_bytes (embedding ))))
328+ self ._conn .execute (
329+ q ,
330+ (sid , dim , sqlite3 .Binary (_as_float32_bytes (embedding ))),
331+ )
286332 return len (batch )
287333
288334 while True :
@@ -360,13 +406,13 @@ def connection(self) -> sqlite3.Connection:
360406 return self ._conn
361407 # ---------- internal ----------
362408
363- def _fetch_rows_by_ids (self , specIDs : List [str ], cols : str ) -> List [sqlite3 .Row ]:
364- if not specIDs :
409+ def _fetch_rows_by_ids (self , spec_ids : List [str ], cols : str ) -> List [sqlite3 .Row ]:
410+ if not spec_ids :
365411 return []
366- placeholders = "," .join ("?" for _ in specIDs )
412+ placeholders = "," .join ("?" for _ in spec_ids )
367413 sql = f"SELECT { cols } FROM { self .table } WHERE spec_id IN ({ placeholders } )"
368414 cur = self ._conn .cursor ()
369- return cur .execute (sql , specIDs ).fetchall ()
415+ return cur .execute (sql , spec_ids ).fetchall ()
370416
371417 def _ensure_schema (self ):
372418 cur = self ._conn .cursor ()
0 commit comments