Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 96 additions & 15 deletions skfp/fingerprints/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from rdkit.Chem import Mol, MolToSmiles, PathToSubmol
from rdkit.Chem.rdmolops import FindAtomEnvironmentOfRadiusN, GetDistanceMatrix
from scipy.sparse import csr_array
from sklearn.utils._param_validation import Interval
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.validation import check_random_state

from skfp.bases import BaseFingerprintTransformer
from skfp.utils import ensure_mols
Expand Down Expand Up @@ -55,8 +56,11 @@ class MAPFingerprint(BaseFingerprintTransformer):
Whether to include chirality information when computing atom types. This is
also known as MAPC fingerprint [3]_ [4]_.

count : bool, default=False
Whether to return binary (bit) features, or their counts.
variant : {"binary", "count", "minhash"}, default="binary"
Output fingerprint variant:
- ``"binary"``: folded binary fingerprint,
- ``"count"``: folded count fingerprint,
- ``"minhash"``: MinHash sketch.

sparse : bool, default=False
Whether to return dense NumPy array, or sparse SciPy CSR array.
Expand Down Expand Up @@ -129,15 +133,18 @@ class MAPFingerprint(BaseFingerprintTransformer):
"radius": [Interval(Integral, 0, None, closed="left")],
"include_duplicated_shingles": ["boolean"],
"include_chirality": ["boolean"],
"variant": [StrOptions({"binary", "count", "minhash"})],
}

_MINHASH_PRIME = np.uint64((1 << 61) - 1)

def __init__(
self,
fp_size: int = 1024,
fp_size: int = 2048,
radius: int = 2,
include_duplicated_shingles: bool = False,
include_chirality: bool = False,
count: bool = False,
variant: str = "binary",
sparse: bool = False,
n_jobs: int | None = None,
batch_size: int | None = None,
Expand All @@ -146,7 +153,7 @@ def __init__(
):
super().__init__(
n_features_out=fp_size,
count=count,
count=(variant == "count"),
sparse=sparse,
n_jobs=n_jobs,
batch_size=batch_size,
Expand All @@ -157,6 +164,7 @@ def __init__(
self.radius = radius
self.include_duplicated_shingles = include_duplicated_shingles
self.include_chirality = include_chirality
self.variant = variant

def transform(
self, X: Sequence[str | Mol], copy: bool = False
Expand All @@ -181,9 +189,15 @@ def transform(

def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray | csr_array:
    """Compute fingerprints for a batch of molecules.

    Count and MinHash variants need the wider ``uint32`` dtype; binary
    bits fit in ``uint8``.

    Parameters
    ----------
    X : sequence of SMILES strings or RDKit ``Mol`` objects
        Molecules to fingerprint.

    Returns
    -------
    np.ndarray or csr_array
        2-D array of shape ``(len(X), fp_size)``; sparse CSR when
        ``self.sparse`` is set.
    """
    X = ensure_mols(X)

    dtype = np.uint32 if self.variant in {"minhash", "count"} else np.uint8

    X = np.stack(
        [self._calculate_single_mol_fingerprint(mol) for mol in X],
        dtype=dtype,
    )

    # X is already a freshly stacked ndarray; wrapping it in np.array()
    # would only create a redundant copy.
    return csr_array(X) if self.sparse else X
def _calculate_single_mol_fingerprint(self, mol: Mol) -> np.ndarray:
    """Compute the fingerprint of a single molecule.

    Atom-pair shingles are extracted and hashed once; the hashes are then
    either MinHash-sketched (``variant="minhash"``) or folded into a
    binary/count vector (``variant="binary"``/``"count"``).

    NOTE(review): this span of the diff interleaved the deleted pre-PR
    folding loop with the new dispatch; only the new implementation is
    kept here.
    """
    atoms_envs = self._get_atom_envs(mol)
    shinglings = self._get_atom_pair_shingles(mol, atoms_envs)
    hashed_shinglings = self._hash_shingles(shinglings)

    if self.variant == "minhash":
        return self._minhash(hashed_shinglings)

    return self._fold(hashed_shinglings)

def _get_atom_envs(self, mol: Mol) -> dict[int, list[str | None]]:
from rdkit.Chem import FindMolChiralCenters
Expand Down Expand Up @@ -300,3 +311,73 @@ def _make_shingle(env_a: str | None, env_b: str | None, distance: int) -> str:

shingle = f"{smaller_env}|{distance}|{larger_env}"
return shingle

@staticmethod
def _hash_shingles(shinglings: set[bytes]) -> np.ndarray:
if not shinglings:
return np.empty(0, dtype=np.uint32)

hashed_values = (
struct.unpack("<I", sha256(shingling).digest()[:4])[0]
for shingling in shinglings
)

return np.fromiter(
hashed_values,
dtype=np.uint32,
count=len(shinglings),
)

def _fold(self, hashed_shinglings: np.ndarray) -> np.ndarray:
folded = np.zeros(
self.fp_size,
dtype=np.uint32 if self.variant == "count" else np.uint8,
)

if hashed_shinglings.size == 0:
return folded

indices = hashed_shinglings % self.fp_size

if self.variant == "count":
np.add.at(folded, indices, 1)
else:
folded[indices] = 1

return folded

def _minhash(self, hashed_shinglings: np.ndarray) -> np.ndarray:
# Return all-zero vector for empty shingle set
if hashed_shinglings.size == 0:
return np.zeros(self.fp_size, dtype=np.uint32)

rng = np.random.default_rng(check_random_state(self.random_state))

# Generate permutation parameters:
# h_i(x) = (a_i * x + b_i) mod P
a = rng.integers(
1,
self._MINHASH_PRIME,
size=self.fp_size,
dtype=np.uint64,
)
b = rng.integers(
0,
self._MINHASH_PRIME,
size=self.fp_size,
dtype=np.uint64,
)

x = hashed_shinglings.astype(np.uint64)

# Apply all MinHash permutations to all hashed shingles at once.
# Broadcasting yields an array of shape (n_shingles, fp_size), where
# entry (j, i) is the value of permutation i applied to shingle j:
# h_i(x_j) = (a_i * x_j + b_i) mod P
permuted = (
x[:, np.newaxis] * a[np.newaxis, :] + b[np.newaxis, :]
) % self._MINHASH_PRIME
mins = permuted.min(axis=0)

# Store the sketch as uint32 to keep output compact and consistent.
return mins.astype(np.uint32)
13 changes: 3 additions & 10 deletions skfp/model_selection/splitters/randomized_scaffold_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numpy.random import Generator, RandomState
from rdkit.Chem import Mol
from sklearn.utils._param_validation import Interval, RealNotInt, validate_params
from sklearn.utils.validation import check_random_state

from skfp.model_selection.splitters.scaffold_split import _create_scaffold_sets
from skfp.model_selection.splitters.utils import (
Expand Down Expand Up @@ -141,11 +142,7 @@ def randomized_scaffold_train_test_split(
)

scaffold_sets = _create_scaffold_sets(data, use_csk)
rng = (
random_state
if isinstance(random_state, RandomState)
else np.random.default_rng(random_state)
)
rng = np.random.default_rng(check_random_state(random_state))
rng.shuffle(scaffold_sets)

train_idxs: list[int] = []
Expand Down Expand Up @@ -315,11 +312,7 @@ def randomized_scaffold_train_valid_test_split(
)

scaffold_sets = _create_scaffold_sets(data, use_csk)
rng = (
random_state
if isinstance(random_state, RandomState)
else np.random.default_rng(random_state)
)
rng = np.random.default_rng(check_random_state(random_state))
rng.shuffle(scaffold_sets)

train_idxs: list[int] = []
Expand Down
114 changes: 100 additions & 14 deletions tests/fingerprints/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def test_map_bit_fingerprint(smallest_smiles_list, smallest_mols_list):


def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(verbose=0, n_jobs=-1)
map_fp = MAPFingerprint(
variant="count",
include_duplicated_shingles=True,
verbose=0,
n_jobs=-1,
)
X_skfp = map_fp.transform(smallest_smiles_list)

X_map = np.stack(
Expand All @@ -33,12 +38,12 @@ def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list):

assert_equal(X_skfp, X_map)
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
assert X_skfp.dtype == np.uint8
assert X_skfp.dtype == np.uint32
assert np.all(X_skfp >= 0)


def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(n_jobs=-1)
map_fp = MAPFingerprint(variant="minhash", n_jobs=-1, random_state=0)
X_skfp = map_fp.transform(smallest_smiles_list)

X_map = np.stack(
Expand All @@ -52,7 +57,7 @@ def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):


def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
map_fp = MAPFingerprint(variant="binary", sparse=True, n_jobs=-1)
X_skfp = map_fp.transform(smallest_smiles_list)

X_map = csr_array(
Expand All @@ -69,7 +74,12 @@ def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list):


def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(include_duplicated_shingles=True, sparse=True, n_jobs=-1)
map_fp = MAPFingerprint(
variant="count",
include_duplicated_shingles=True,
sparse=True,
n_jobs=-1,
)
X_skfp = map_fp.transform(smallest_smiles_list)

X_map = csr_array(
Expand All @@ -78,38 +88,114 @@ def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list):

assert_equal(X_skfp.data, X_map.data)
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
assert X_skfp.dtype == np.uint8
assert X_skfp.dtype == np.uint32
assert np.all(X_skfp.data > 0)

map_fp = MAPFingerprint(
include_duplicated_shingles=True, sparse=True, count=True, n_jobs=-1
)

def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
X_skfp = map_fp.transform(smallest_smiles_list)

X_map = csr_array(
[map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list],
dtype=int,
)

assert_equal(X_skfp.data, X_map.data)
assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
assert X_skfp.dtype == np.uint32
assert np.all(X_skfp.data > 0)
assert np.issubdtype(X_skfp.dtype, np.integer)


def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list):
map_fp = MAPFingerprint(sparse=True, n_jobs=-1)
def test_map_sparse_minhash_fingerprint(smallest_smiles_list, smallest_mols_list):
    # Sparse MinHash output must match the per-molecule computation.
    map_fp = MAPFingerprint(
        variant="minhash",
        sparse=True,
        n_jobs=-1,
        random_state=0,
    )
    X_skfp = map_fp.transform(smallest_smiles_list)

    rows = [
        map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list
    ]
    X_map = csr_array(rows, dtype=np.uint32)

    assert_equal(X_skfp.data, X_map.data)
    assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size))
    assert X_skfp.dtype == np.uint32
    assert np.issubdtype(X_skfp.dtype, np.integer)


def test_map_minhash_same_random_state_is_reproducible(smallest_smiles_list):
    # Identical seeds must yield identical MinHash sketches.
    outputs = [
        MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1).transform(
            smallest_smiles_list
        )
        for _ in range(2)
    ]

    assert_equal(outputs[0], outputs[1])


def test_map_minhash_different_random_state_changes_output(smallest_smiles_list):
    # Different seeds give different permutations, hence different sketches.
    outputs = [
        MAPFingerprint(variant="minhash", random_state=seed, n_jobs=-1).transform(
            smallest_smiles_list
        )
        for seed in (123, 456)
    ]

    assert not np.array_equal(outputs[0], outputs[1])
Comment on lines +129 to +146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR!
It would be nice to also include a test that ensures that, for the same random_state, given molecules are hashed in the same way regardless of their order and of the size of the list passed to the .transform() method.
This will save us from potential problems in case someone modifies the ._minhash() method in the future.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test for this (if that is what you meant).



def test_map_minhash_is_independent_of_input_order_and_batch_size():
    # A molecule's sketch must depend only on the molecule and the seed,
    # not on its position in the batch or the batch size.
    smiles = [
        "CC(=O)Oc1ccccc1C(=O)O",
        "CCO",
        "c1ccccc1",
        "CCN(CC)CC",
    ]

    map_fp = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1)
    X_full = map_fp.transform(smiles)

    # same molecules, different order
    permutation = [2, 0, 3, 1]
    X_reordered = map_fp.transform([smiles[idx] for idx in permutation])

    # compare molecule-by-molecule, not row-by-row
    for row, source_idx in enumerate(permutation):
        assert_equal(X_full[source_idx], X_reordered[row])

    # same molecules, singleton calls
    for idx, single in enumerate(smiles):
        assert_equal(X_full[idx], map_fp.transform([single])[0])

    # same molecules, smaller batch
    assert_equal(X_full[:2], map_fp.transform(smiles[:2]))


def test_map_binary_ignores_random_state(smallest_smiles_list):
    # The binary variant is seed-free: random_state must not affect output.
    X_a = MAPFingerprint(variant="binary", random_state=123, n_jobs=-1).transform(
        smallest_smiles_list
    )
    X_b = MAPFingerprint(variant="binary", random_state=456, n_jobs=-1).transform(
        smallest_smiles_list
    )

    assert_equal(X_a, X_b)


def test_map_count_ignores_random_state(smallest_smiles_list):
    # The count variant is seed-free: random_state must not affect output.
    X_a = MAPFingerprint(variant="count", random_state=123, n_jobs=-1).transform(
        smallest_smiles_list
    )
    X_b = MAPFingerprint(variant="count", random_state=456, n_jobs=-1).transform(
        smallest_smiles_list
    )

    assert_equal(X_a, X_b)


def test_map_chirality(smallest_mols_list):
# smoke test, this should not throw an error
map_fp = MAPFingerprint(include_chirality=True, n_jobs=-1)
Expand Down