diff --git a/skfp/fingerprints/map.py b/skfp/fingerprints/map.py index 676fef11..6f8a6039 100644 --- a/skfp/fingerprints/map.py +++ b/skfp/fingerprints/map.py @@ -9,7 +9,8 @@ from rdkit.Chem import Mol, MolToSmiles, PathToSubmol from rdkit.Chem.rdmolops import FindAtomEnvironmentOfRadiusN, GetDistanceMatrix from scipy.sparse import csr_array -from sklearn.utils._param_validation import Interval +from sklearn.utils._param_validation import Interval, StrOptions +from sklearn.utils.validation import check_random_state from skfp.bases import BaseFingerprintTransformer from skfp.utils import ensure_mols @@ -55,8 +56,11 @@ class MAPFingerprint(BaseFingerprintTransformer): Whether to include chirality information when computing atom types. This is also known as MAPC fingerprint [3]_ [4]_. - count : bool, default=False - Whether to return binary (bit) features, or their counts. + variant : {"binary", "count", "minhash"}, default="binary" + Output fingerprint variant: + - ``"binary"``: folded binary fingerprint, + - ``"count"``: folded count fingerprint, + - ``"minhash"``: MinHash sketch. sparse : bool, default=False Whether to return dense NumPy array, or sparse SciPy CSR array. @@ -129,15 +133,18 @@ class MAPFingerprint(BaseFingerprintTransformer): "radius": [Interval(Integral, 0, None, closed="left")], "include_duplicated_shingles": ["boolean"], "include_chirality": ["boolean"], + "variant": [StrOptions({"binary", "count", "minhash"})], } + _MINHASH_PRIME = np.uint64((1 << 61) - 1) + def __init__( self, - fp_size: int = 1024, + fp_size: int = 2048, radius: int = 2, include_duplicated_shingles: bool = False, include_chirality: bool = False, - count: bool = False, + variant: str = "binary", sparse: bool = False, n_jobs: int | None = None, batch_size: int | None = None, @@ -146,7 +153,7 @@ def __init__( ): super().__init__( n_features_out=fp_size, - count=count, + count=(variant == "count"), sparse=sparse, n_jobs=n_jobs, batch_size=batch_size, @@ -157,6 +164,7 @@ def __init__( self.radius = radius self.include_duplicated_shingles = include_duplicated_shingles self.include_chirality = include_chirality + self.variant = variant def transform( self, X: Sequence[str | Mol], copy: bool = False @@ -181,9 +189,15 @@ def transform( def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray | csr_array: X = ensure_mols(X) + + if self.variant in {"minhash", "count"}: + dtype = np.uint32 + else: + dtype = np.uint8 + X = np.stack( [self._calculate_single_mol_fingerprint(mol) for mol in X], - dtype=np.uint32 if self.count else np.uint8, + dtype=dtype, ) return csr_array(X) if self.sparse else np.array(X) @@ -196,15 +210,12 @@ def _calculate_single_mol_fingerprint(self, mol: Mol) -> np.ndarray: atoms_envs = self._get_atom_envs(mol) shinglings = self._get_atom_pair_shingles(mol, atoms_envs) + hashed_shinglings = self._hash_shingles(shinglings) - folded = np.zeros(self.fp_size, dtype=np.uint32 if self.count else np.uint8) - for shingling in shinglings: - hashed = struct.unpack(" dict[int, list[str | None]]: from rdkit.Chem import FindMolChiralCenters @@ -300,3 +311,73 @@ def _make_shingle(env_a: str | None, env_b: str | None, distance: int) -> str: shingle = f"{smaller_env}|{distance}|{larger_env}" return shingle + + @staticmethod + def _hash_shingles(shinglings: set[bytes]) -> np.ndarray: + if not shinglings: + return np.empty(0, dtype=np.uint32) + + hashed_values = ( + struct.unpack(" np.ndarray: + folded = np.zeros( + self.fp_size, + dtype=np.uint32 if self.variant == "count" else np.uint8, + ) + + if hashed_shinglings.size == 0: + return folded + + indices = hashed_shinglings % self.fp_size + + if self.variant == "count": + np.add.at(folded, indices, 1) + else: + folded[indices] = 1 + + return folded + + def _minhash(self, hashed_shinglings: np.ndarray) -> np.ndarray: + # Return all-zero vector for empty shingle set + if hashed_shinglings.size == 0: + return np.zeros(self.fp_size, dtype=np.uint32) + + rng = np.random.default_rng(check_random_state(self.random_state)) + + # Generate permutation parameters: + # h_i(x) = (a_i * x + b_i) mod P + a = rng.integers( + 1, + self._MINHASH_PRIME, + size=self.fp_size, + dtype=np.uint64, + ) + b = rng.integers( + 0, + self._MINHASH_PRIME, + size=self.fp_size, + dtype=np.uint64, + ) + + x = hashed_shinglings.astype(np.uint64) + + # Apply all MinHash permutations to all hashed shingles at once. + # Broadcasting yields an array of shape (n_shingles, fp_size), where + # entry (j, i) is the value of permutation i applied to shingle j: + # h_i(x_j) = (a_i * x_j + b_i) mod P + permuted = ( + x[:, np.newaxis] * a[np.newaxis, :] + b[np.newaxis, :] + ) % self._MINHASH_PRIME + mins = permuted.min(axis=0) + + # Store the sketch as uint32 to keep output compact and consistent. + return mins.astype(np.uint32) diff --git a/skfp/model_selection/splitters/randomized_scaffold_split.py b/skfp/model_selection/splitters/randomized_scaffold_split.py index 828b7040..1ed8af2f 100644 --- a/skfp/model_selection/splitters/randomized_scaffold_split.py +++ b/skfp/model_selection/splitters/randomized_scaffold_split.py @@ -6,6 +6,7 @@ from numpy.random import Generator, RandomState from rdkit.Chem import Mol from sklearn.utils._param_validation import Interval, RealNotInt, validate_params +from sklearn.utils.validation import check_random_state from skfp.model_selection.splitters.scaffold_split import _create_scaffold_sets from skfp.model_selection.splitters.utils import ( @@ -141,11 +142,7 @@ def randomized_scaffold_train_test_split( ) scaffold_sets = _create_scaffold_sets(data, use_csk) - rng = ( - random_state - if isinstance(random_state, RandomState) - else np.random.default_rng(random_state) - ) + rng = np.random.default_rng(check_random_state(random_state)) rng.shuffle(scaffold_sets) train_idxs: list[int] = [] @@ -315,11 +312,7 @@ def randomized_scaffold_train_valid_test_split( ) scaffold_sets = _create_scaffold_sets(data, use_csk) - rng = ( - random_state - if isinstance(random_state, RandomState) - else np.random.default_rng(random_state) - ) + rng = np.random.default_rng(check_random_state(random_state)) rng.shuffle(scaffold_sets) train_idxs: list[int] = [] diff --git a/tests/fingerprints/map.py b/tests/fingerprints/map.py index 8fbca745..c9e7adce 100644 --- a/tests/fingerprints/map.py +++ b/tests/fingerprints/map.py @@ -24,7 +24,12 @@ def test_map_bit_fingerprint(smallest_smiles_list, smallest_mols_list): def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list): - map_fp = MAPFingerprint(verbose=0, n_jobs=-1) + map_fp = MAPFingerprint( + variant="count", + include_duplicated_shingles=True, + verbose=0, + n_jobs=-1, + ) X_skfp = map_fp.transform(smallest_smiles_list) X_map = np.stack( @@ -33,12 +38,12 @@ def test_map_count_fingerprint(smallest_smiles_list, smallest_mols_list): assert_equal(X_skfp, X_map) assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size)) - assert X_skfp.dtype == np.uint8 + assert X_skfp.dtype == np.uint32 assert np.all(X_skfp >= 0) def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list): - map_fp = MAPFingerprint(n_jobs=-1) + map_fp = MAPFingerprint(variant="minhash", n_jobs=-1, random_state=0) X_skfp = map_fp.transform(smallest_smiles_list) X_map = np.stack( @@ -52,7 +57,7 @@ def test_map_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list): def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list): - map_fp = MAPFingerprint(sparse=True, n_jobs=-1) + map_fp = MAPFingerprint(variant="binary", sparse=True, n_jobs=-1) X_skfp = map_fp.transform(smallest_smiles_list) X_map = csr_array( @@ -69,7 +74,12 @@ def test_map_sparse_bit_fingerprint(smallest_smiles_list, smallest_mols_list): def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list): - map_fp = MAPFingerprint(include_duplicated_shingles=True, sparse=True, n_jobs=-1) + map_fp = MAPFingerprint( + variant="count", + include_duplicated_shingles=True, + sparse=True, + n_jobs=-1, + ) X_skfp = map_fp.transform(smallest_smiles_list) X_map = csr_array( @@ -78,38 +88,114 @@ def test_map_sparse_count_fingerprint(smallest_smiles_list, smallest_mols_list): assert_equal(X_skfp.data, X_map.data) assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size)) - assert X_skfp.dtype == np.uint8 + assert X_skfp.dtype == np.uint32 assert np.all(X_skfp.data > 0) - map_fp = MAPFingerprint( - include_duplicated_shingles=True, sparse=True, count=True, n_jobs=-1 - ) + +def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list): + map_fp = MAPFingerprint(sparse=True, n_jobs=-1) X_skfp = map_fp.transform(smallest_smiles_list) X_map = csr_array( [map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list], + dtype=int, ) assert_equal(X_skfp.data, X_map.data) assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size)) - assert X_skfp.dtype == np.uint32 - assert np.all(X_skfp.data > 0) + assert np.issubdtype(X_skfp.dtype, np.integer) -def test_map_sparse_raw_hashes_fingerprint(smallest_smiles_list, smallest_mols_list): - map_fp = MAPFingerprint(sparse=True, n_jobs=-1) +def test_map_sparse_minhash_fingerprint(smallest_smiles_list, smallest_mols_list): + map_fp = MAPFingerprint( + variant="minhash", + sparse=True, + n_jobs=-1, + random_state=0, + ) X_skfp = map_fp.transform(smallest_smiles_list) X_map = csr_array( [map_fp._calculate_single_mol_fingerprint(mol) for mol in smallest_mols_list], - dtype=int, + dtype=np.uint32, ) assert_equal(X_skfp.data, X_map.data) assert_equal(X_skfp.shape, (len(smallest_smiles_list), map_fp.fp_size)) + assert X_skfp.dtype == np.uint32 assert np.issubdtype(X_skfp.dtype, np.integer) +def test_map_minhash_same_random_state_is_reproducible(smallest_smiles_list): + map_fp_1 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1) + map_fp_2 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1) + + X_1 = map_fp_1.transform(smallest_smiles_list) + X_2 = map_fp_2.transform(smallest_smiles_list) + + assert_equal(X_1, X_2) + + +def test_map_minhash_different_random_state_changes_output(smallest_smiles_list): + map_fp_1 = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1) + map_fp_2 = MAPFingerprint(variant="minhash", random_state=456, n_jobs=-1) + + X_1 = map_fp_1.transform(smallest_smiles_list) + X_2 = map_fp_2.transform(smallest_smiles_list) + + assert not np.array_equal(X_1, X_2) + + +def test_map_minhash_is_independent_of_input_order_and_batch_size(): + smiles = [ + "CC(=O)Oc1ccccc1C(=O)O", + "CCO", + "c1ccccc1", + "CCN(CC)CC", + ] + + map_fp = MAPFingerprint(variant="minhash", random_state=123, n_jobs=-1) + + X_full = map_fp.transform(smiles) + + # same molecules, different order + reordered_indices = [2, 0, 3, 1] + reordered_smiles = [smiles[i] for i in reordered_indices] + X_reordered = map_fp.transform(reordered_smiles) + + # compare molecule-by-molecule, not row-by-row + for original_idx, reordered_idx in enumerate(reordered_indices): + assert_equal(X_full[reordered_idx], X_reordered[original_idx]) + + # same molecules, smaller subsets / singleton calls + for i, smi in enumerate(smiles): + X_single = map_fp.transform([smi]) + assert_equal(X_full[i], X_single[0]) + + X_subset = map_fp.transform(smiles[:2]) + assert_equal(X_full[:2], X_subset) + + +def test_map_binary_ignores_random_state(smallest_smiles_list): + map_fp_1 = MAPFingerprint(variant="binary", random_state=123, n_jobs=-1) + map_fp_2 = MAPFingerprint(variant="binary", random_state=456, n_jobs=-1) + + X_1 = map_fp_1.transform(smallest_smiles_list) + X_2 = map_fp_2.transform(smallest_smiles_list) + + assert_equal(X_1, X_2) + + +def test_map_count_ignores_random_state(smallest_smiles_list): + map_fp_1 = MAPFingerprint(variant="count", random_state=123, n_jobs=-1) + map_fp_2 = MAPFingerprint(variant="count", random_state=456, n_jobs=-1) + + X_1 = map_fp_1.transform(smallest_smiles_list) + X_2 = map_fp_2.transform(smallest_smiles_list) + + assert_equal(X_1, X_2) + + def test_map_chirality(smallest_mols_list): # smoke test, this should not throw an error map_fp = MAPFingerprint(include_chirality=True, n_jobs=-1)