Skip to content

Commit e4895ed

Browse files
MAP4 implementation fixes (#528)
1 parent 6c8b6b3 commit e4895ed

File tree

3 files changed

+199
-39
lines changed

3 files changed

+199
-39
lines changed

skfp/fingerprints/map.py

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from rdkit.Chem import Mol, MolToSmiles, PathToSubmol
1010
from rdkit.Chem.rdmolops import FindAtomEnvironmentOfRadiusN, GetDistanceMatrix
1111
from scipy.sparse import csr_array
12-
from sklearn.utils._param_validation import Interval
12+
from sklearn.utils._param_validation import Interval, StrOptions
13+
from sklearn.utils.validation import check_random_state
1314

1415
from skfp.bases import BaseFingerprintTransformer
1516
from skfp.utils import ensure_mols
@@ -55,8 +56,11 @@ class MAPFingerprint(BaseFingerprintTransformer):
5556
Whether to include chirality information when computing atom types. This is
5657
also known as MAPC fingerprint [3]_ [4]_.
5758
58-
count : bool, default=False
59-
Whether to return binary (bit) features, or their counts.
59+
variant : {"binary", "count", "minhash"}, default="binary"
60+
Output fingerprint variant:
61+
- ``"binary"``: folded binary fingerprint,
62+
- ``"count"``: folded count fingerprint,
63+
- ``"minhash"``: MinHash sketch.
6064
6165
sparse : bool, default=False
6266
Whether to return dense NumPy array, or sparse SciPy CSR array.
@@ -129,15 +133,18 @@ class MAPFingerprint(BaseFingerprintTransformer):
129133
"radius": [Interval(Integral, 0, None, closed="left")],
130134
"include_duplicated_shingles": ["boolean"],
131135
"include_chirality": ["boolean"],
136+
"variant": [StrOptions({"binary", "count", "minhash"})],
132137
}
133138

139+
_MINHASH_PRIME = np.uint64((1 << 61) - 1)
140+
134141
def __init__(
135142
self,
136-
fp_size: int = 1024,
143+
fp_size: int = 2048,
137144
radius: int = 2,
138145
include_duplicated_shingles: bool = False,
139146
include_chirality: bool = False,
140-
count: bool = False,
147+
variant: str = "binary",
141148
sparse: bool = False,
142149
n_jobs: int | None = None,
143150
batch_size: int | None = None,
@@ -146,7 +153,7 @@ def __init__(
146153
):
147154
super().__init__(
148155
n_features_out=fp_size,
149-
count=count,
156+
count=(variant == "count"),
150157
sparse=sparse,
151158
n_jobs=n_jobs,
152159
batch_size=batch_size,
@@ -157,6 +164,7 @@ def __init__(
157164
self.radius = radius
158165
self.include_duplicated_shingles = include_duplicated_shingles
159166
self.include_chirality = include_chirality
167+
self.variant = variant
160168

161169
def transform(
162170
self, X: Sequence[str | Mol], copy: bool = False
@@ -181,9 +189,15 @@ def transform(
181189

182190
def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray | csr_array:
183191
X = ensure_mols(X)
192+
193+
if self.variant in {"minhash", "count"}:
194+
dtype = np.uint32
195+
else:
196+
dtype = np.uint8
197+
184198
X = np.stack(
185199
[self._calculate_single_mol_fingerprint(mol) for mol in X],
186-
dtype=np.uint32 if self.count else np.uint8,
200+
dtype=dtype,
187201
)
188202

189203
return csr_array(X) if self.sparse else np.array(X)
@@ -196,15 +210,12 @@ def _calculate_single_mol_fingerprint(self, mol: Mol) -> np.ndarray:
196210

197211
atoms_envs = self._get_atom_envs(mol)
198212
shinglings = self._get_atom_pair_shingles(mol, atoms_envs)
213+
hashed_shinglings = self._hash_shingles(shinglings)
199214

200-
folded = np.zeros(self.fp_size, dtype=np.uint32 if self.count else np.uint8)
201-
for shingling in shinglings:
202-
hashed = struct.unpack("<I", sha256(shingling).digest()[:4])[0]
203-
if self.count:
204-
folded[hashed % self.fp_size] += 1
205-
else:
206-
folded[hashed % self.fp_size] = 1
207-
return folded
215+
if self.variant == "minhash":
216+
return self._minhash(hashed_shinglings)
217+
218+
return self._fold(hashed_shinglings)
208219

209220
def _get_atom_envs(self, mol: Mol) -> dict[int, list[str | None]]:
210221
from rdkit.Chem import FindMolChiralCenters
@@ -300,3 +311,73 @@ def _make_shingle(env_a: str | None, env_b: str | None, distance: int) -> str:
300311

301312
shingle = f"{smaller_env}|{distance}|{larger_env}"
302313
return shingle
314+
315+
@staticmethod
316+
def _hash_shingles(shinglings: set[bytes]) -> np.ndarray:
317+
if not shinglings:
318+
return np.empty(0, dtype=np.uint32)
319+
320+
hashed_values = (
321+
struct.unpack("<I", sha256(shingling).digest()[:4])[0]
322+
for shingling in shinglings
323+
)
324+
325+
return np.fromiter(
326+
hashed_values,
327+
dtype=np.uint32,
328+
count=len(shinglings),
329+
)
330+
331+
def _fold(self, hashed_shinglings: np.ndarray) -> np.ndarray:
332+
folded = np.zeros(
333+
self.fp_size,
334+
dtype=np.uint32 if self.variant == "count" else np.uint8,
335+
)
336+
337+
if hashed_shinglings.size == 0:
338+
return folded
339+
340+
indices = hashed_shinglings % self.fp_size
341+
342+
if self.variant == "count":
343+
np.add.at(folded, indices, 1)
344+
else:
345+
folded[indices] = 1
346+
347+
return folded
348+
349+
def _minhash(self, hashed_shinglings: np.ndarray) -> np.ndarray:
350+
# Return all-zero vector for empty shingle set
351+
if hashed_shinglings.size == 0:
352+
return np.zeros(self.fp_size, dtype=np.uint32)
353+
354+
rng = np.random.default_rng(check_random_state(self.random_state))
355+
356+
# Generate permutation parameters:
357+
# h_i(x) = (a_i * x + b_i) mod P
358+
a = rng.integers(
359+
1,
360+
self._MINHASH_PRIME,
361+
size=self.fp_size,
362+
dtype=np.uint64,
363+
)
364+
b = rng.integers(
365+
0,
366+
self._MINHASH_PRIME,
367+
size=self.fp_size,
368+
dtype=np.uint64,
369+
)
370+
371+
x = hashed_shinglings.astype(np.uint64)
372+
373+
# Apply all MinHash permutations to all hashed shingles at once.
374+
# Broadcasting yields an array of shape (n_shingles, fp_size), where
375+
# entry (j, i) is the value of permutation i applied to shingle j:
376+
# h_i(x_j) = (a_i * x_j + b_i) mod P
377+
permuted = (
378+
x[:, np.newaxis] * a[np.newaxis, :] + b[np.newaxis, :]
379+
) % self._MINHASH_PRIME
380+
mins = permuted.min(axis=0)
381+
382+
# Store the sketch as uint32 to keep output compact and consistent.
383+
return mins.astype(np.uint32)

skfp/model_selection/splitters/randomized_scaffold_split.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.random import Generator, RandomState
77
from rdkit.Chem import Mol
88
from sklearn.utils._param_validation import Interval, RealNotInt, validate_params
9+
from sklearn.utils.validation import check_random_state
910

1011
from skfp.model_selection.splitters.scaffold_split import _create_scaffold_sets
1112
from skfp.model_selection.splitters.utils import (
@@ -141,11 +142,7 @@ def randomized_scaffold_train_test_split(
141142
)
142143

143144
scaffold_sets = _create_scaffold_sets(data, use_csk)
144-
rng = (
145-
random_state
146-
if isinstance(random_state, RandomState)
147-
else np.random.default_rng(random_state)
148-
)
145+
rng = np.random.default_rng(check_random_state(random_state))
149146
rng.shuffle(scaffold_sets)
150147

151148
train_idxs: list[int] = []
@@ -315,11 +312,7 @@ def randomized_scaffold_train_valid_test_split(
315312
)
316313

317314
scaffold_sets = _create_scaffold_sets(data, use_csk)
318-
rng = (
319-
random_state
320-
if isinstance(random_state, RandomState)
321-
else np.random.default_rng(random_state)
322-
)
315+
rng = np.random.default_rng(check_random_state(random_state))
323316
rng.shuffle(scaffold_sets)
324317

325318
train_idxs: list[int] = []

tests/fingerprints/map.py

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ def test_map_bit_fingerprint(smallest_smiles_list, smallest_mols_list):
2424

2525

2626
def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list):
27-
map_fp = MAPFingerprint(verbose=0, n_jobs=-1)
27+
map_fp = MAPFingerprint(
28+
variant="count",
29+
include_duplicated_shingles=True,
30+
verbose=0,
31+
n_jobs=-1,
32+
)
2833
X_skfp = map_fp.transform(smallest_smiles_list)
2934

3035
X_map = np.stack(
@@ -33,12 +38,12 @@ def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list):
3338

3439
assert_equal(X_skfp, X_map)
3540
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
36-
assert X_skfp.dtype == np.uint8
41+
assert X_skfp.dtype == np.uint32
3742
assert np.all(X_skfp >= 0)
3843

3944

4045
def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
41-
map_fp = MAPFingerprint(n_jobs=-1)
46+
map_fp = MAPFingerprint(variant="minhash", n_jobs=-1, random_state=0)
4247
X_skfp = map_fp.transform(smallest_smiles_list)
4348

4449
X_map = np.stack(
@@ -52,7 +57,7 @@ def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
5257

5358

5459
def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list):
55-
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
60+
map_fp = MAPFingerprint(variant="binary", sparse=True, n_jobs=-1)
5661
X_skfp = map_fp.transform(smallest_smiles_list)
5762

5863
X_map = csr_array(
@@ -69,7 +74,12 @@ def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list):
6974

7075

7176
def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list):
72-
map_fp = MAPFingerprint(include_duplicated_shingles=True, sparse=True, n_jobs=-1)
77+
map_fp = MAPFingerprint(
78+
variant="count",
79+
include_duplicated_shingles=True,
80+
sparse=True,
81+
n_jobs=-1,
82+
)
7383
X_skfp = map_fp.transform(smallest_smiles_list)
7484

7585
X_map = csr_array(
@@ -78,38 +88,114 @@ def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list):
7888

7989
assert_equal(X_skfp.data, X_map.data)
8090
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
81-
assert X_skfp.dtype == np.uint8
91+
assert X_skfp.dtype == np.uint32
8292
assert np.all(X_skfp.data > 0)
8393

84-
map_fp = MAPFingerprint(
85-
include_duplicated_shingles=True, sparse=True, count=True, n_jobs=-1
86-
)
94+
95+
def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
96+
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
8797
X_skfp = map_fp.transform(smallest_smiles_list)
8898

8999
X_map = csr_array(
90100
[map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list],
101+
dtype=int,
91102
)
92103

93104
assert_equal(X_skfp.data, X_map.data)
94105
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
95-
assert X_skfp.dtype == np.uint32
96-
assert np.all(X_skfp.data > 0)
106+
assert np.issubdtype(X_skfp.dtype, np.integer)
97107

98108

99-
def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
100-
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
109+
def test_map_sparse_minhash_fingerprint(smallest_smiles_list, smallest_mols_list):
110+
map_fp = MAPFingerprint(
111+
variant="minhash",
112+
sparse=True,
113+
n_jobs=-1,
114+
random_state=0,
115+
)
101116
X_skfp = map_fp.transform(smallest_smiles_list)
102117

103118
X_map = csr_array(
104119
[map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list],
105-
dtype=int,
120+
dtype=np.uint32,
106121
)
107122

108123
assert_equal(X_skfp.data, X_map.data)
109124
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
125+
assert X_skfp.dtype == np.uint32
110126
assert np.issubdtype(X_skfp.dtype, np.integer)
111127

112128

129+
def test_map_minhash_same_random_state_is_reproducible(smallest_smiles_list):
130+
map_fp_1 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1)
131+
map_fp_2 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1)
132+
133+
X_1 = map_fp_1.transform(smallest_smiles_list)
134+
X_2 = map_fp_2.transform(smallest_smiles_list)
135+
136+
assert_equal(X_1, X_2)
137+
138+
139+
def test_map_minhash_different_random_state_changes_output(smallest_smiles_list):
140+
map_fp_1 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1)
141+
map_fp_2 = MAPFingerprint(variant="minhash", random_state=456, n_jobs=-1)
142+
143+
X_1 = map_fp_1.transform(smallest_smiles_list)
144+
X_2 = map_fp_2.transform(smallest_smiles_list)
145+
146+
assert not np.array_equal(X_1, X_2)
147+
148+
149+
def test_map_minhash_is_independent_of_input_order_and_batch_size():
150+
smiles = [
151+
"CC(=O)Oc1ccccc1C(=O)O",
152+
"CCO",
153+
"c1ccccc1",
154+
"CCN(CC)CC",
155+
]
156+
157+
map_fp = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1)
158+
159+
X_full = map_fp.transform(smiles)
160+
161+
# same molecules, different order
162+
reordered_indices = [2, 0, 3, 1]
163+
reordered_smiles = [smiles[i] for i in reordered_indices]
164+
X_reordered = map_fp.transform(reordered_smiles)
165+
166+
# compare molecule-by-molecule, not row-by-row
167+
for original_idx, reordered_idx in enumerate(reordered_indices):
168+
assert_equal(X_full[reordered_idx], X_reordered[original_idx])
169+
170+
# same molecules, smaller subsets / singleton calls
171+
for i, smi in enumerate(smiles):
172+
X_single = map_fp.transform([smi])
173+
assert_equal(X_full[i], X_single[0])
174+
175+
X_subset = map_fp.transform(smiles[:2])
176+
assert_equal(X_full[:2], X_subset)
177+
178+
179+
def test_map_binary_ignores_random_state(smallest_smiles_list):
180+
map_fp_1 = MAPFingerprint(variant="binary", random_state=123, n_jobs=-1)
181+
map_fp_2 = MAPFingerprint(variant="binary", random_state=456, n_jobs=-1)
182+
183+
X_1 = map_fp_1.transform(smallest_smiles_list)
184+
X_2 = map_fp_2.transform(smallest_smiles_list)
185+
186+
assert_equal(X_1, X_2)
187+
188+
189+
def test_map_count_ignores_random_state(smallest_smiles_list):
190+
map_fp_1 = MAPFingerprint(variant="count", random_state=123, n_jobs=-1)
191+
map_fp_2 = MAPFingerprint(variant="count", random_state=456, n_jobs=-1)
192+
193+
X_1 = map_fp_1.transform(smallest_smiles_list)
194+
X_2 = map_fp_2.transform(smallest_smiles_list)
195+
196+
assert_equal(X_1, X_2)
197+
198+
113199
def test_map_chirality(smallest_mols_list):
114200
# smoke test, this should not throw an error
115201
map_fp = MAPFingerprint(include_chirality=True, n_jobs=-1)

0 commit comments

Comments
 (0)