Skip to content

Commit 64560bf

Browse files
Klekota-Roth prefix optimization (#470)
1 parent ad559ba commit 64560bf

File tree

8 files changed

+5192
-4993
lines changed

8 files changed

+5192
-4993
lines changed

skfp/fingerprints/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .functional_groups import FunctionalGroupsFingerprint
1111
from .getaway import GETAWAYFingerprint
1212
from .ghose_crippen import GhoseCrippenFingerprint
13-
from .klekota_roth import KlekotaRothFingerprint
13+
from .klekota_roth.klekota_roth import KlekotaRothFingerprint
1414
from .laggner import LaggnerFingerprint
1515
from .layered import LayeredFingerprint
1616
from .lingo import LingoFingerprint

skfp/fingerprints/klekota_roth.py

Lines changed: 0 additions & 4992 deletions
This file was deleted.

skfp/fingerprints/klekota_roth/__init__.py

Whitespace-only changes.
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from collections import deque
2+
from collections.abc import Sequence
3+
4+
import numpy as np
5+
from rdkit.Chem import Mol
6+
from scipy.sparse import csr_array
7+
8+
from skfp.bases import BaseSubstructureFingerprint
9+
from skfp.utils import ensure_mols
10+
11+
from .smarts_tree import PatternNode, _load_tree
12+
13+
14+
class KlekotaRothFingerprint(BaseSubstructureFingerprint):
15+
"""
16+
Klekota-Roth fingerprint.
17+
18+
A substructure fingerprint based on [1]_, with implementation based on CDK [2]_.
19+
Tests for presence of 4860 predefined substructures which are predisposed for
20+
bioactivity.
21+
22+
Parameters
23+
----------
24+
count : bool, default=False
25+
Whether to return binary (bit) features, or their counts.
26+
27+
sparse : bool, default=False
28+
Whether to return dense NumPy array, or sparse SciPy CSR array.
29+
30+
n_jobs : int, default=None
31+
The number of jobs to run in parallel. :meth:`transform` is parallelized
32+
over the input molecules. ``None`` means 1 unless in a
33+
:obj:`joblib.parallel_backend` context. ``-1`` means using all processors.
34+
See scikit-learn documentation on ``n_jobs`` for more details.
35+
36+
batch_size : int, default=None
37+
Number of inputs processed in each batch. ``None`` divides input data into
38+
equal-sized parts, as many as ``n_jobs``.
39+
40+
verbose : int or dict, default=0
41+
Controls the verbosity when computing fingerprints.
42+
If a dictionary is passed, it is treated as kwargs for ``tqdm()``,
43+
and can be used to control the progress bar.
44+
45+
Attributes
46+
----------
47+
n_features_out : int = 4860
48+
Number of output features, size of fingerprints.
49+
50+
requires_conformers : bool = False
51+
This fingerprint uses only 2D molecular graphs and does not require conformers.
52+
53+
References
54+
----------
55+
.. [1] `Klekota, Justin, and Frederick P Roth.
56+
“Chemical substructures that enrich for biological activity.”
57+
Bioinformatics (Oxford, England) vol. 24,21 (2008): 2518-25.
58+
<https://pubmed.ncbi.nlm.nih.gov/18784118/>`_
59+
60+
.. [2] `Chemistry Development Kit (CDK) KlekotaRothFingerprinter
61+
<https://cdk.github.io/cdk/latest/docs/api/org/openscience/cdk/fingerprint/KlekotaRothFingerprinter.html>`_
62+
63+
Examples
64+
--------
65+
>>> from skfp.fingerprints import KlekotaRothFingerprint
66+
>>> smiles = ["O", "CC", "[C-]#N", "CC=O"]
67+
>>> fp = KlekotaRothFingerprint()
68+
>>> fp
69+
KlekotaRothFingerprint()
70+
71+
>>> fp.transform(smiles)
72+
array([[0, 0, 0, ..., 0, 0, 0],
73+
[0, 0, 0, ..., 0, 0, 0],
74+
[0, 0, 0, ..., 0, 0, 0],
75+
[0, 0, 0, ..., 0, 0, 0]], shape=(4, 4860), dtype=uint8)
76+
"""
77+
78+
def __init__(
79+
self,
80+
count: bool = False,
81+
sparse: bool = False,
82+
n_jobs: int | None = None,
83+
batch_size: int | None = None,
84+
verbose: int | dict = 0,
85+
):
86+
# note that those patterns were released as public domain:
87+
# https://github.com/cdk/cdk/blob/main/descriptor/fingerprint/src/main/java/org/openscience/cdk/fingerprint/KlekotaRothFingerprinter.java
88+
self._feature_names: list[str]
89+
self._pattern_atoms: dict[str, Mol]
90+
self._root: PatternNode
91+
92+
self._root, self._feature_names, self._pattern_atoms = _load_tree()
93+
super().__init__(
94+
patterns=self._feature_names,
95+
count=count,
96+
sparse=sparse,
97+
n_jobs=n_jobs,
98+
batch_size=batch_size,
99+
verbose=verbose,
100+
)
101+
102+
def get_feature_names_out(self, input_features=None) -> np.ndarray: # noqa: ARG002
103+
"""
104+
Get fingerprint output feature names. They are raw SMARTS patterns
105+
used as feature definitions.
106+
107+
Parameters
108+
----------
109+
input_features : array-like of str or None, default=None
110+
Unused, kept for scikit-learn compatibility.
111+
112+
Returns
113+
-------
114+
feature_names_out : ndarray of str objects
115+
Klekota-Roth feature names.
116+
"""
117+
return np.asarray(self._feature_names, dtype=object)
118+
119+
def transform(
120+
self, X: Sequence[str | Mol], copy: bool = False
121+
) -> np.ndarray | csr_array:
122+
"""
123+
Compute Klekota-Roth fingerprints.
124+
125+
Parameters
126+
----------
127+
X : {sequence, array-like} of shape (n_samples,)
128+
Sequence containing SMILES strings or RDKit ``Mol`` objects.
129+
130+
copy : bool, default=False
131+
Copy the input X or not.
132+
133+
Returns
134+
-------
135+
X : {ndarray, sparse matrix} of shape (n_samples, 4860)
136+
Array with fingerprints.
137+
"""
138+
return super().transform(X, copy)
139+
140+
def _calculate_fingerprint(self, X: Sequence[str | Mol]) -> np.ndarray | csr_array:
141+
X = ensure_mols(X)
142+
143+
n_bits = self.n_features_out
144+
bits = np.zeros((len(X), n_bits), dtype=np.uint32 if self.count else np.uint8)
145+
root_children = self._root.children
146+
147+
if self.count:
148+
set_value = lambda mol, pattern: len(mol.GetSubstructMatches(pattern))
149+
else:
150+
set_value = lambda _mol, _pattern: 1
151+
152+
for i, mol in enumerate(X):
153+
stack: deque[PatternNode] = deque(root_children)
154+
atom_contents = self._count_atom_patterns(mol)
155+
while stack:
156+
node = stack.pop()
157+
158+
if any(
159+
atom_contents[key] < val
160+
for key, val in node.atom_requirements.items()
161+
):
162+
continue
163+
164+
if not mol.HasSubstructMatch(node.pattern_mol):
165+
continue
166+
167+
if node.is_terminal:
168+
bits[i][node.feature_bit] = set_value(mol, node.pattern_mol)
169+
170+
stack.extend(node.children)
171+
172+
return csr_array(bits) if self.sparse else bits
173+
174+
def _count_atom_patterns(self, mol: Mol) -> dict[str, int]:
175+
"""
176+
Count occurrences of atom-level patterns in a molecule.
177+
"""
178+
atom_contents = dict.fromkeys(self._pattern_atoms, 0)
179+
for atom in mol.GetAtoms():
180+
symbol = atom.GetSymbol()
181+
atomic_num = atom.GetAtomicNum()
182+
hcount = atom.GetTotalNumHs()
183+
charge = atom.GetFormalCharge()
184+
aromatic = atom.GetIsAromatic()
185+
186+
symbol = symbol.lower() if aromatic else symbol
187+
188+
# plain element symbol
189+
if symbol in atom_contents:
190+
atom_contents[symbol] += 1
191+
192+
# atomic number pattern
193+
key = f"[#{atomic_num}]"
194+
if key in atom_contents:
195+
atom_contents[key] += 1
196+
197+
# hydrogen count pattern
198+
key = f"[{symbol}&H{hcount}]"
199+
if key in atom_contents:
200+
atom_contents[key] += 1
201+
202+
# charge pattern
203+
if charge != 0:
204+
sign = "+" if charge > 0 else "-"
205+
key = f"[{symbol}&{sign}]"
206+
if key in atom_contents:
207+
atom_contents[key] += 1
208+
209+
# negation of hydrogen
210+
if atomic_num != 1:
211+
atom_contents["[!#1]"] += 1
212+
213+
return atom_contents
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import json
2+
from dataclasses import dataclass, field
3+
from functools import lru_cache
4+
from pathlib import Path
5+
6+
from rdkit import Chem
7+
from rdkit.Chem import Mol
8+
9+
_TREE_PATH = Path(__file__).parent / "tree_data.json"
10+
11+
12+
@dataclass(slots=True, frozen=True)
13+
class PatternNode:
14+
"""
15+
Node in the SMARTS pattern tree.
16+
17+
Attributes
18+
----------
19+
smarts : str = None
20+
SMARTS string defining the pattern.
21+
22+
pattern_mol : Mol = None
23+
RDKit Mol object of the pattern.
24+
25+
is_terminal : bool = False
26+
Whether this node corresponds to a complete pattern or just a prefix.
27+
28+
feature_bit : int = None
29+
Index of the corresponding fingerprint bit.
30+
31+
children : list[_PatternNode] = []
32+
Child nodes.
33+
34+
atom_requirements : defaultdict[str, int]
35+
Minimal atom requirements needed to match at this node.
36+
"""
37+
38+
smarts: str | None = None
39+
pattern_mol: Mol | None = None
40+
is_terminal: bool = False
41+
feature_bit: int | None = None
42+
atom_requirements: dict[str, int] = field(default_factory=dict)
43+
children: list["PatternNode"] = field(default_factory=list)
44+
45+
46+
def _dict_to_node(d: dict, feature_names: list[str]) -> PatternNode:
47+
"""
48+
Recursively convert a dict representation of a pattern tree
49+
into a _PatternNode tree.
50+
"""
51+
node = PatternNode(
52+
smarts=d.get("smarts"),
53+
pattern_mol=Chem.MolFromSmarts(d.get("smarts") or ""),
54+
is_terminal=d.get("is_terminal", False),
55+
feature_bit=d.get("feature_bit"),
56+
atom_requirements=d.get("atom_requirements", {}),
57+
children=[
58+
_dict_to_node(node_dict, feature_names)
59+
for node_dict in d.get("children", [])
60+
],
61+
)
62+
63+
if node.is_terminal:
64+
# node.smarts and node.feature_bit can be None only in non-terminal nodes.
65+
# These values are set when generating the tree.
66+
feature_names[int(node.feature_bit)] = node.smarts # type: ignore
67+
68+
return node
69+
70+
71+
@lru_cache(maxsize=1)
72+
def _load_tree() -> tuple[PatternNode, list[str], dict[str, Mol]]:
73+
"""
74+
Load the pattern tree from a JSON file into internal representation.
75+
"""
76+
path = _TREE_PATH
77+
if not path.exists():
78+
raise FileNotFoundError("Klekota-Roth SMARTS tree file not found")
79+
80+
with path.open("r", encoding="utf-8") as file:
81+
data = json.load(file)
82+
83+
feature_names = [""] * data["n_terminal_nodes"]
84+
pattern_atoms = {key: Chem.MolFromSmarts(key) for key in data["atoms"]}
85+
root = _dict_to_node(data["tree"], feature_names)
86+
return root, feature_names, pattern_atoms

skfp/fingerprints/klekota_roth/tree_data.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)