Skip to content

Commit e412750

Browse files
authored
Filter condition indicators for physicochemical filters (#471)
1 parent 64560bf commit e412750

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1662
-151
lines changed

skfp/bases/base_filter.py

Lines changed: 89 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from abc import ABC, abstractmethod
23
from collections.abc import Sequence
34
from copy import deepcopy
@@ -7,7 +8,7 @@
78
from joblib import effective_n_jobs
89
from rdkit.Chem import Mol
910
from sklearn.base import BaseEstimator, TransformerMixin
10-
from sklearn.utils._param_validation import InvalidParameterError
11+
from sklearn.utils._param_validation import InvalidParameterError, StrOptions
1112
from tqdm import tqdm
1213

1314
from skfp.utils import ensure_mols, run_in_parallel
@@ -37,10 +38,25 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
3738
Whether to allow violating one of the rules for a molecule. This makes the
3839
filter less restrictive.
3940
41+
return_type : {"mol", "indicators", "condition_indicators"}, default="mol"
42+
What values to return as the filtering result.
43+
44+
- ``"mol"`` - return a list of molecules remaining in the dataset after filtering
45+
- ``"indicators"`` - return a binary vector with indicators which molecules pass
46+
the filter (1) and which would be removed (0)
47+
- ``"condition_indicators"`` - return a Pandas DataFrame with molecules in rows,
48+
filter conditions in columns, and 0/1 indicators whether a given condition was
49+
fulfilled by a given molecule
50+
4051
return_indicators : bool, default=False
4152
Whether to return a binary vector with indicators which molecules pass the
4253
filter, instead of list of molecules.
4354
55+
.. deprecated:: 1.17
56+
``return_indicators`` is deprecated and will be removed in version 2.0.
57+
Use ``return_type`` instead. If ``return_indicators`` is set to ``True``,
58+
it will take precedence over ``return_type``.
59+
4460
n_jobs : int, default=None
4561
The number of jobs to run in parallel. :meth:`transform_x_y` and
4662
:meth:`transform` are parallelized over the input molecules. ``None`` means 1
@@ -60,6 +76,7 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
6076
# parameters common for all filters
6177
_parameter_constraints: dict = {
6278
"allow_one_violation": ["boolean"],
79+
"return_type": [StrOptions({"mol", "indicators", "condition_indicators"})],
6380
"return_indicators": ["boolean"],
6481
"n_jobs": [Integral, None],
6582
"batch_size": [Integral, None],
@@ -69,24 +86,55 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
6986
def __init__(
7087
self,
7188
allow_one_violation: bool = False,
89+
return_type: str = "mol",
7290
return_indicators: bool = False,
7391
n_jobs: int | None = None,
7492
batch_size: int | None = None,
7593
verbose: int | dict = 0,
7694
):
7795
self.allow_one_violation = allow_one_violation
96+
self.return_type = return_type
7897
self.return_indicators = return_indicators
7998
self.n_jobs = n_jobs
8099
self.batch_size = batch_size
81100
self.verbose = verbose
82101

102+
if return_indicators:
103+
warnings.warn(
104+
"return_indicators is deprecated and will be removed in 2.0, "
105+
"use return_type instead"
106+
)
107+
83108
def __sklearn_is_fitted__(self) -> bool:
84109
"""
85110
Unused, kept for scikit-learn compatibility. This class assumes stateless
86111
transformers and always returns True.
87112
"""
88113
return True
89114

115+
def get_feature_names_out(self, input_features=None) -> np.ndarray:
116+
"""
117+
Get filter condition names. They correspond to molecular descriptors (for
118+
physicochemical filters) or SMARTS patterns (for substructural filters).
119+
120+
Parameters
121+
----------
122+
input_features : array-like of str or None, default=None
123+
Unused, kept for scikit-learn compatibility.
124+
125+
Returns
126+
-------
127+
feature_names_out : ndarray of str objects
128+
Filter condition names.
129+
"""
130+
if not hasattr(self, "_condition_names"):
131+
raise AttributeError(
132+
f"Filter condition names not yet supported for "
133+
f"{self.__class__.__name__}"
134+
)
135+
136+
return np.array(self._condition_names)
137+
90138
def fit(self, X: Sequence[str | Mol], y: np.ndarray | None = None):
91139
"""Unused, kept for scikit-learn compatibility.
92140
@@ -133,40 +181,43 @@ def transform(
133181
self, X: Sequence[str | Mol], copy: bool = False
134182
) -> list[str | Mol] | np.ndarray:
135183
"""
136-
Apply a filter to input molecules. Output depends on ``return_indicators``
184+
Apply a filter to input molecules. Output depends on ``return_type``
137185
attribute.
138186
139187
Parameters
140188
----------
141-
X : {sequence, array-like} of shape (n_samples,)
142-
Sequence containing RDKit ``Mol`` objects.
189+
X : {sequence of str or Mol}
190+
Sequence containing SMILES strings or RDKit ``Mol`` objects.
143191
144192
copy : bool, default=False
145193
Copy the input X or not.
146194
147195
Returns
148196
-------
149-
X : list of shape (n_samples_conf_gen,) or array of shape (n_samples,)
150-
List with filtered molecules, or indicator vector which molecules
151-
fulfill the filter rules.
197+
X : list of shape (n_samples,) or array of shape (n_samples,)
198+
or array of shape (n_samples, n_conditions)
199+
List with filtered molecules or indicators.
152200
"""
153201
filter_ind = self._get_filter_indicators(X, copy)
202+
154203
if self.return_indicators:
155204
return filter_ind
156-
else:
205+
elif self.return_type == "mol":
157206
return [mol for idx, mol in enumerate(X) if filter_ind[idx]]
207+
else:
208+
return filter_ind
158209

159210
def transform_x_y(
160211
self, X: Sequence[str | Mol], y: np.ndarray, copy: bool = False
161212
) -> tuple[list[str | Mol], np.ndarray] | tuple[np.ndarray, np.ndarray]:
162213
"""
163-
Apply a filter to input molecules. Output depends on ``return_indicators``
214+
Apply a filter to input molecules. Output depends on ``return_type``
164215
attribute.
165216
166217
Parameters
167218
----------
168-
X : {sequence, array-like} of shape (n_samples,)
169-
Sequence containing RDKit ``Mol`` objects.
219+
X : {sequence of str or Mol}
220+
Sequence containing SMILES strings or RDKit ``Mol`` objects.
170221
171222
y : array-like of shape (n_samples,)
172223
Array with labels for molecules.
@@ -176,27 +227,29 @@ def transform_x_y(
176227
177228
Returns
178229
-------
179-
X : list of shape (n_samples_conf_gen,) or array of shape (n_samples,)
180-
List with filtered molecules, or indicator vector which molecules
181-
fulfill the filter rules.
230+
X : list of shape (n_samples,) or array of shape (n_samples,)
231+
or array of shape (n_samples, n_conditions)
232+
List with filtered molecules or indicators.
182233
183-
y : np.ndarray of shape (n_samples_conf_gen,)
234+
y : np.ndarray of shape (n_samples,)
184235
Array with labels for molecules.
185236
"""
186237
filter_ind = self._get_filter_indicators(X, copy)
238+
187239
if self.return_indicators:
188240
return filter_ind, y
189-
else:
241+
elif self.return_type == "mol":
190242
mols = [mol for idx, mol in enumerate(X) if filter_ind[idx]]
191243
y = y[filter_ind]
192244
return mols, y
245+
else:
246+
return filter_ind, y
193247

194248
def _get_filter_indicators(
195249
self, mols: Sequence[str | Mol], copy: bool
196250
) -> np.ndarray:
197251
self._validate_params()
198252
mols = deepcopy(mols) if copy else mols
199-
mols = ensure_mols(mols)
200253

201254
n_jobs = effective_n_jobs(self.n_jobs)
202255
if n_jobs == 1:
@@ -207,23 +260,38 @@ def _get_filter_indicators(
207260
else:
208261
filter_indicators = self._filter_mols_batch(mols)
209262
else:
263+
flatten_results = self.return_type != "condition_indicators"
264+
210265
filter_indicators = run_in_parallel(
211266
self._filter_mols_batch,
212267
data=mols,
213268
n_jobs=n_jobs,
214269
batch_size=self.batch_size,
215-
flatten_results=True,
270+
flatten_results=flatten_results,
216271
verbose=self.verbose,
217272
)
218273

274+
if self.return_type == "condition_indicators":
275+
filter_indicators = np.vstack(filter_indicators)
276+
219277
return filter_indicators
220278

221-
def _filter_mols_batch(self, mols: list[Mol]) -> np.ndarray:
279+
def _filter_mols_batch(self, mols: Sequence[str | Mol]) -> np.ndarray:
280+
mols = ensure_mols(mols)
281+
222282
filter_indicators = [self._apply_mol_filter(mol) for mol in mols]
223-
return np.array(filter_indicators, dtype=bool)
283+
284+
if self.return_indicators:
285+
filter_indicators = np.array(filter_indicators, dtype=bool)
286+
elif self.return_type == "condition_indicators":
287+
filter_indicators = np.vstack(filter_indicators)
288+
else:
289+
filter_indicators = np.array(filter_indicators, dtype=bool)
290+
291+
return filter_indicators
224292

225293
@abstractmethod
226-
def _apply_mol_filter(self, mol: Mol) -> bool:
294+
def _apply_mol_filter(self, mol: Mol) -> bool | np.ndarray:
227295
pass
228296

229297
def _validate_params(self) -> None:

skfp/filters/beyond_ro5.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from rdkit.Chem import Mol
23
from rdkit.Chem.Crippen import MolLogP
34
from rdkit.Chem.Descriptors import MolWt
@@ -37,10 +38,25 @@ class BeyondRo5Filter(BaseFilter):
3738
Whether to allow violating one of the rules for a molecule. This makes the
3839
filter less restrictive.
3940
41+
return_type : {"mol", "indicators", "condition_indicators"}, default="mol"
42+
What values to return as the filtering result.
43+
44+
- ``"mol"`` - return a list of molecules remaining in the dataset after filtering
45+
- ``"indicators"`` - return a binary vector with indicators which molecules pass
46+
the filter (1) and which would be removed (0)
47+
- ``"condition_indicators"`` - return a Pandas DataFrame with molecules in rows,
48+
filter conditions in columns, and 0/1 indicators whether a given condition was
49+
fulfilled by a given molecule
50+
4051
return_indicators : bool, default=False
4152
Whether to return a binary vector with indicators which molecules pass the
4253
filter, instead of list of molecules.
4354
55+
.. deprecated:: 1.17
56+
``return_indicators`` is deprecated and will be removed in version 2.0.
57+
Use ``return_type`` instead. If ``return_indicators`` is set to ``True``,
58+
it will take precedence over ``return_type``.
59+
4460
n_jobs : int, default=None
4561
The number of jobs to run in parallel. :meth:`transform_x_y` and
4662
:meth:`transform` are parallelized over the input molecules. ``None`` means 1
@@ -85,20 +101,30 @@ class BeyondRo5Filter(BaseFilter):
85101
def __init__(
86102
self,
87103
allow_one_violation: bool = False,
104+
return_type: str = "mol",
88105
return_indicators: bool = False,
89106
n_jobs: int | None = None,
90107
batch_size: int | None = None,
91108
verbose: int | dict = 0,
92109
):
93110
super().__init__(
94111
allow_one_violation=allow_one_violation,
112+
return_type=return_type,
95113
return_indicators=return_indicators,
96114
n_jobs=n_jobs,
97115
batch_size=batch_size,
98116
verbose=verbose,
99117
)
118+
self._condition_names = [
119+
"MolWeight <= 1000",
120+
"-2 <= logP <= 10",
121+
"HBA <= 15",
122+
"HBD <= 6",
123+
"TPSA <= 250",
124+
"rotatable bonds <= 6",
125+
]
100126

101-
def _apply_mol_filter(self, mol: Mol) -> bool:
127+
def _apply_mol_filter(self, mol: Mol) -> bool | np.ndarray:
102128
rules = [
103129
MolWt(mol) <= 1000,
104130
-2 <= MolLogP(mol) <= 10,
@@ -107,6 +133,10 @@ def _apply_mol_filter(self, mol: Mol) -> bool:
107133
CalcTPSA(mol) <= 250,
108134
CalcNumRotatableBonds(mol) <= 20,
109135
]
136+
137+
if self.return_type == "condition_indicators":
138+
return np.array(rules, dtype=bool)
139+
110140
passed_rules = sum(rules)
111141

112142
if self.allow_one_violation:

0 commit comments

Comments
 (0)