Skip to content

Commit dec423d

Browse files
committed
remove redundant normalization calls/functions
1 parent 5f5c7c7 commit dec423d

File tree

4 files changed

+50
-34
lines changed

4 files changed

+50
-34
lines changed

ms2query/database/ann_vector_index.py

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,6 @@ def __init__(self, dim: int = 500):
115115
self._index = None
116116
self._comp_ids: Optional[np.ndarray] = None
117117

118-
# ---------- direct build from arrays ----------
119118
def build_index(
120119
self,
121120
vectors: np.ndarray,
@@ -124,15 +123,25 @@ def build_index(
124123
M: int = 16,
125124
ef_construction: int = 200,
126125
post_init_ef: int = 200,
127-
assume_normalized: bool = True,
128126
) -> None:
127+
"""Build index from dense vectors and spec_ids.
128+
129+
Parameters
130+
----------
131+
vectors : np.ndarray
132+
2D array of shape (N, dim) with float32 vectors.
133+
spec_ids : Iterable[str]
134+
Iterable of spec_id strings of length N.
135+
M : int
136+
HNSW M parameter (connectivity)
137+
ef_construction : int
138+
HNSW efConstruction parameter
139+
post_init_ef : int
140+
HNSW query-time ef parameter
141+
"""
129142
X = np.asarray(vectors, dtype=np.float32)
130143
if X.ndim != 2 or X.shape[1] != self.dim:
131144
raise ValueError(f"Expected vectors shape (N, {self.dim}), got {X.shape}")
132-
if not assume_normalized:
133-
n = np.linalg.norm(X, axis=1, keepdims=True)
134-
n = np.maximum(n, 1e-12)
135-
X = X / n
136145
ids = np.asarray(list(spec_ids), dtype=object)
137146
if ids.shape[0] != X.shape[0]:
138147
raise ValueError("spec_ids length must match number of vectors.")
@@ -146,7 +155,6 @@ def build_index(
146155
self._ids = ids
147156
self._meta = {
148157
"type": "ANNMS2DeepIndex",
149-
"assume_normalized": bool(assume_normalized),
150158
"M": M,
151159
"ef_construction": ef_construction,
152160
"post_init_ef": post_init_ef,
@@ -168,7 +176,6 @@ def build_index_from_sqlite(
168176
M: int = 16,
169177
ef_construction: int = 200,
170178
post_init_ef: int = 200,
171-
l2_normalize: bool = True,
172179
) -> int:
173180
"""
174181
Streams embeddings from SQLite and constructs an HNSW index in-place.
@@ -191,8 +198,6 @@ def build_index_from_sqlite(
191198
HNSW efConstruction parameter
192199
post_init_ef : int
193200
HNSW query-time ef parameter
194-
l2_normalize : bool
195-
Whether to L2-normalize vectors before indexing
196201
197202
Returns
198203
-------
@@ -244,12 +249,6 @@ def build_index_from_sqlite(
244249
if total == 0:
245250
raise ValueError(f"No embeddings loaded from {embeddings_table}.")
246251

247-
# Optional L2 normalization (in-place, cache-friendly)
248-
if l2_normalize:
249-
norms = np.linalg.norm(X, axis=1, keepdims=True)
250-
np.maximum(norms, 1e-12, out=norms)
251-
X /= norms
252-
253252
# Build the HNSW index with a SINGLE batch add
254253
index = nmslib.init(method='hnsw', space='cosinesimil', data_type=nmslib.DataType.DENSE_VECTOR)
255254
index.addDataPointBatch(X)
@@ -266,7 +265,6 @@ def build_index_from_sqlite(
266265
"M": M,
267266
"ef_construction": ef_construction,
268267
"post_init_ef": post_init_ef,
269-
"l2_normalize": bool(l2_normalize),
270268
}
271269
return int(total)
272270

@@ -302,22 +300,24 @@ def query(
302300
vector: np.ndarray,
303301
k: int = 10,
304302
ef: Optional[int] = None,
305-
assume_normalized: Optional[bool] = None
306303
) -> List[Tuple[str, float]]:
307304
"""Query the index with a single vector.
305+
306+
Parameters
307+
----------
308+
vector : np.ndarray
309+
1D array of shape (dim,) with float32 vector.
310+
k : int
311+
Number of nearest neighbors to return.
312+
ef : Optional[int]
313+
nmslib ef parameter (higher = better recall / slower).
308314
"""
309315
if self._index is None:
310316
raise RuntimeError("Index not built or loaded.")
311317
v = np.asarray(vector, dtype=np.float32).reshape(1, -1)
312318
if v.shape[1] != self.dim:
313319
raise ValueError(f"Query vector must have dim={self.dim}")
314320

315-
norm_flag = self._meta.get("assume_normalized", True) if assume_normalized is None else assume_normalized
316-
if not norm_flag:
317-
n = np.linalg.norm(v, axis=1, keepdims=True)
318-
n = np.maximum(n, 1e-12)
319-
v = v / n
320-
321321
if ef is not None:
322322
self._index.setQueryTimeParams({'ef': int(ef)})
323323

ms2query/library_io.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,6 @@ def create_new_library(
192192
M=params["M"],
193193
ef_construction=params["ef_construction"],
194194
post_init_ef=params["post_init_ef"],
195-
l2_normalize=True,
196195
)
197196
_print_progress(f"Indexed {n_vecs} embedding vectors.")
198197
emb_prefix = str(out_dir / _EMB_INDEX_BASENAME)

ms2query/ms2query_library.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,8 @@ def process_spectra(self, spectra: list[Spectrum]) -> List[Spectrum]:
6767
def compute_embeddings(self, spectra: list[Spectrum]) -> np.ndarray:
6868
"""
6969
Compute MS2DeepScore embeddings for arbitrary query spectra.
70+
71+
Spectra will be preprocessed via self.process_spectra(...) first.
7072
"""
7173
if not spectra:
7274
return np.empty((0, 0), dtype=np.float32)
@@ -88,7 +90,6 @@ def query_embedding_index(
8890
*,
8991
k: int = 10,
9092
ef: Optional[int] = None,
91-
assume_normalized: bool = True,
9293
return_dataframe: bool = True,
9394
) -> Union[List[List[Dict[str, Any]]], "pd.DataFrame"]:
9495
"""
@@ -106,8 +107,6 @@ def query_embedding_index(
106107
Top-k to return.
107108
ef : Optional[int]
108109
nmslib ef (higher = better recall / slower).
109-
assume_normalized : bool
110-
If False, will L2-normalize vectors again before query (normally keep True).
111110
return_dataframe : bool
112111
If True, returns a tidy DataFrame with columns:
113112
['query_ix','rank','spec_id','score']
@@ -126,7 +125,7 @@ def query_embedding_index(
126125
for qi in range(embeddings.shape[0]):
127126
# TODO: make faster by querying batch-wise
128127
# EmbeddingIndex.query returns list[(spec_id, similarity)]
129-
hits = self.embedding_index.query(embeddings[qi], k=k, ef=ef, assume_normalized=assume_normalized)
128+
hits = self.embedding_index.query(embeddings[qi], k=k, ef=ef)
130129
# convert to standard structure
131130
one = []
132131
for rk, (spec_id, score) in enumerate(hits, start=1):
@@ -143,6 +142,15 @@ def query_embedding_index(
143142
df = pd.DataFrame(rows, columns=["query_ix", "rank", "spec_id", "score"])
144143
return df
145144

145+
def query_compounds_by_spectra(
146+
self,
147+
spectra: Union[Spectrum, Sequence[Spectrum]],
148+
*,
149+
k: int = 10,
150+
ef: Optional[int] = None,
151+
return_dataframe: bool = True,
152+
):
153+
pass
146154
# ----------------------------- helpers / optional glue -----------------------------
147155

148156
def set_embedding_index(self, index: EmbeddingIndex) -> None:
@@ -170,7 +178,7 @@ def query_by_spec_ids(
170178

171179
results_all: List[List[Dict[str, Any]]] = []
172180
for qi in range(X.shape[0]):
173-
hits = self.embedding_index.query(X[qi], k=k, ef=ef, assume_normalized=True)
181+
hits = self.embedding_index.query(X[qi], k=k, ef=ef)
174182
one = [{"rank": rk + 1, "spec_id": sid, "score": float(score)} for rk, (sid, score) in enumerate(hits)]
175183
results_all.append(one)
176184

tests/test_ann_vector_index.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,27 +13,30 @@ def _mk_unit_vecs(*rows):
1313
n = np.maximum(n, 1e-12)
1414
return X / n
1515

16+
1617
def test_build_index_and_query_dense():
1718
X = _mk_unit_vecs([1,0,0], [0,1,0], [0,0,1])
1819
ids = ["a","b","c"]
1920
idx = EmbeddingIndex(dim=3)
20-
idx.build_index(X, ids, assume_normalized=True)
21+
idx.build_index(X, ids)
2122
# Query close to [1,0,0]
2223
q = _mk_unit_vecs([0.9, 0.1, 0.0])[0]
2324
res = idx.query(q, k=2)
2425
assert [r[0] for r in res] == ["a", "b"]
2526
assert res[0][1] > res[1][1] # similarity desc
2627

28+
2729
def test_build_index_normalizes_when_requested():
2830
X = np.array([[2.0,0,0],[0,2.0,0]], dtype=np.float32)
2931
ids = ["x","y"]
3032
idx = EmbeddingIndex(dim=3)
31-
idx.build_index(X, ids, assume_normalized=False)
33+
idx.build_index(X, ids)
3234
q = np.array([1.0,0,0], dtype=np.float32)
33-
out = idx.query(q, k=1, assume_normalized=False)
35+
out = idx.query(q, k=1)
3436
assert out[0][0] == "x"
3537
assert 0.99 <= out[0][1] <= 1.0
3638

39+
3740
def test_query_errors_and_dim_check():
3841
idx = EmbeddingIndex(dim=3)
3942
with pytest.raises(RuntimeError):
@@ -42,6 +45,7 @@ def test_query_errors_and_dim_check():
4245
with pytest.raises(ValueError, match="dim=3"):
4346
idx.query(np.zeros(4, np.float32))
4447

48+
4549
def test_save_and_load_roundtrip_dense(tmp_path):
4650
X = _mk_unit_vecs([1,0,0],[0,1,0],[0,0,1])
4751
ids = ["a","b","c"]
@@ -62,6 +66,7 @@ def test_save_and_load_roundtrip_dense(tmp_path):
6266
meta = json.load(f)
6367
assert meta["space"] == "cosinesimil"
6468

69+
6570
@pytest.mark.parametrize("batch_rows", [1, 2, 3])
6671
def test_build_index_from_sqlite_streams_and_orders(batch_rows):
6772
conn = sqlite3.connect(":memory:")
@@ -75,12 +80,13 @@ def test_build_index_from_sqlite_streams_and_orders(batch_rows):
7580
],
7681
)
7782
idx = EmbeddingIndex(dim=3)
78-
n = idx.build_index_from_sqlite(conn, embeddings_table="embeddings", batch_rows=batch_rows, l2_normalize=True)
83+
n = idx.build_index_from_sqlite(conn, embeddings_table="embeddings", batch_rows=batch_rows)
7984
assert n == 3
8085
# Should be ordered by spec_id ascending ("id_1","id_2")
8186
out = idx.query(np.array([1.0, 0.0, 0.0], np.float32), k=2)
8287
assert [o[0] for o in out] == ["id_1", "id_2"]
8388

89+
8490
def test_build_index_from_sqlite_errors():
8591
conn = sqlite3.connect(":memory:")
8692
conn.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)")
@@ -124,13 +130,15 @@ def test_tuples_to_csr_basic():
124130
r2 = csr[2].toarray().ravel()
125131
np.testing.assert_allclose(r2, [1.0, 0, 0, 0, 1.0])
126132

133+
127134
def test_tuples_to_csr_errors_when_index_out_of_bounds():
128135
tuples = [
129136
(np.array([0, 6], dtype=np.int32), np.array([1.0, 2.0], dtype=np.float32)),
130137
]
131138
with pytest.raises(ValueError, match=">= dim"):
132139
tuples_to_csr(tuples, dim=5)
133140

141+
134142
def test_csr_row_from_tuple_coalesces_and_validates():
135143
idxs = np.array([2, 2, 0], dtype=np.int32)
136144
vals = np.array([1.0, 2.0, 3.0], dtype=np.float32)
@@ -142,6 +150,7 @@ def test_csr_row_from_tuple_coalesces_and_validates():
142150
with pytest.raises(ValueError, match="Query index"):
143151
csr_row_from_tuple((np.array([5]), np.array([1.0], np.float32)), dim=5)
144152

153+
145154
def test_l1_norms_csr():
146155
X = sp.csr_matrix(
147156
np.array([[1.0, 2.0, 0.0], [0.0, 0.5, 0.5], [3.0, 0.0, 1.0]], dtype=np.float32)

0 commit comments

Comments
 (0)