1+ import warnings
12from abc import ABC , abstractmethod
23from collections .abc import Sequence
34from copy import deepcopy
78from joblib import effective_n_jobs
89from rdkit .Chem import Mol
910from sklearn .base import BaseEstimator , TransformerMixin
10- from sklearn .utils ._param_validation import InvalidParameterError
11+ from sklearn .utils ._param_validation import InvalidParameterError , StrOptions
1112from tqdm import tqdm
1213
1314from skfp .utils import ensure_mols , run_in_parallel
@@ -37,10 +38,25 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
3738 Whether to allow violating one of the rules for a molecule. This makes the
3839 filter less restrictive.
3940
41+ return_type : {"mol", "indicators", "condition_indicators"}, default="mol"
42+ What values to return as the filtering result.
43+
44+ - ``"mol"`` - return a list of molecules remaining in the dataset after filtering
45+ - ``"indicators"`` - return a binary vector with indicators which molecules pass
46+ the filter (1) and which would be removed (0)
47+ - ``"condition_indicators"`` - return a Pandas DataFrame with molecules in rows,
48+ filter conditions in columns, and 0/1 indicators whether a given condition was
49+ fulfilled by a given molecule
50+
4051 return_indicators : bool, default=False
4152 Whether to return a binary vector with indicators which molecules pass the
4253 filter, instead of list of molecules.
4354
55+ .. deprecated:: 1.17
56+ ``return_indicators`` is deprecated and will be removed in version 2.0.
57+ Use ``return_type`` instead. If ``return_indicators`` is set to ``True``,
58+ it will take precedence over ``return_type``.
59+
4460 n_jobs : int, default=None
4561 The number of jobs to run in parallel. :meth:`transform_x_y` and
4662 :meth:`transform` are parallelized over the input molecules. ``None`` means 1
@@ -60,6 +76,7 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
6076 # parameters common for all filters
6177 _parameter_constraints : dict = {
6278 "allow_one_violation" : ["boolean" ],
79+ "return_type" : [StrOptions ({"mol" , "indicators" , "condition_indicators" })],
6380 "return_indicators" : ["boolean" ],
6481 "n_jobs" : [Integral , None ],
6582 "batch_size" : [Integral , None ],
@@ -69,24 +86,55 @@ class BaseFilter(ABC, BaseEstimator, TransformerMixin):
def __init__(
    self,
    allow_one_violation: bool = False,
    return_type: str = "mol",
    return_indicators: bool = False,
    n_jobs: int | None = None,
    batch_size: int | None = None,
    verbose: int | dict = 0,
):
    """
    Store the parameters common to all filters.

    Values are assigned to attributes of the same name without any checking.
    The only extra work done here is emitting a deprecation warning when the
    legacy ``return_indicators`` flag is enabled.

    Parameters
    ----------
    allow_one_violation : bool, default=False
        Whether to allow violating one of the rules for a molecule.

    return_type : {"mol", "indicators", "condition_indicators"}, default="mol"
        What values to return as the filtering result.

    return_indicators : bool, default=False
        Deprecated, use ``return_type`` instead. When ``True``, a
        ``FutureWarning`` is issued.

    n_jobs : int, default=None
        The number of jobs to run in parallel.

    batch_size : int, default=None
        Number of inputs processed in each batch.

    verbose : int or dict, default=0
        Verbosity setting, stored as given.
    """
    self.allow_one_violation = allow_one_violation
    self.return_type = return_type
    self.return_indicators = return_indicators
    self.n_jobs = n_jobs
    self.batch_size = batch_size
    self.verbose = verbose

    if return_indicators:
        # An explicit category lets users filter deprecations separately from
        # generic UserWarnings, and FutureWarning is shown by default.
        # stacklevel=2 attributes the warning to the code that called this
        # constructor rather than to this line.
        warnings.warn(
            "return_indicators is deprecated and will be removed in 2.0, "
            "use return_type instead",
            FutureWarning,
            stacklevel=2,
        )
107+
def __sklearn_is_fitted__(self) -> bool:
    """
    Report this transformer as fitted.

    Filters are assumed to be stateless, so there is no fitted state to check
    and the answer is always ``True``. Kept for scikit-learn compatibility.
    """
    return True
89114
def get_feature_names_out(self, input_features=None) -> np.ndarray:
    """
    Get filter condition names. They correspond to molecular descriptors (for
    physicochemical filters) or SMARTS patterns (for substructural filters).

    Parameters
    ----------
    input_features : array-like of str or None, default=None
        Unused, kept for scikit-learn compatibility.

    Returns
    -------
    feature_names_out : ndarray of str objects
        Filter condition names, as a 1D array with ``object`` dtype.

    Raises
    ------
    AttributeError
        If the filter does not define the ``_condition_names`` attribute.
    """
    if not hasattr(self, "_condition_names"):
        raise AttributeError(
            "Filter condition names not yet supported for "
            f"{self.__class__.__name__}"
        )

    # Explicit object dtype: without it NumPy infers the dtype, so an empty
    # list of names would become a float64 array and non-empty lists would use
    # a fixed-width unicode dtype instead of plain str objects.
    return np.asarray(self._condition_names, dtype=object)
137+
90138 def fit (self , X : Sequence [str | Mol ], y : np .ndarray | None = None ):
91139 """Unused, kept for scikit-learn compatibility.
92140
def transform(
    self, X: Sequence[str | Mol], copy: bool = False
) -> list[str | Mol] | np.ndarray:
    """
    Run the filter on input molecules. The kind of output is selected by the
    ``return_type`` attribute; a truthy ``return_indicators`` attribute
    overrides it and forces indicator output.

    Parameters
    ----------
    X : {sequence of str or Mol}
        Sequence containing SMILES strings or RDKit ``Mol`` objects.

    copy : bool, default=False
        Copy the input X or not.

    Returns
    -------
    X : list of shape (n_samples,) or array of shape (n_samples,)
        or array of shape (n_samples, n_conditions)
        Elements of the input that passed the filter, or the indicators
        computed by ``_get_filter_indicators``, returned unchanged.
    """
    indicators = self._get_filter_indicators(X, copy)

    # molecules are returned only when nothing requests indicator output
    wants_molecules = not self.return_indicators and self.return_type == "mol"
    if not wants_molecules:
        return indicators

    return [mol for mol, keep in zip(X, indicators) if keep]
158209
def transform_x_y(
    self, X: Sequence[str | Mol], y: np.ndarray, copy: bool = False
) -> tuple[list[str | Mol], np.ndarray] | tuple[np.ndarray, np.ndarray]:
    """
    Apply a filter to input molecules and keep labels aligned with the result.
    Output depends on ``return_type`` attribute; a truthy ``return_indicators``
    attribute takes precedence and forces indicator output.

    Parameters
    ----------
    X : {sequence of str or Mol}
        Sequence containing SMILES strings or RDKit ``Mol`` objects.

    y : array-like of shape (n_samples,)
        Array with labels for molecules.

    copy : bool, default=False
        Copy the input X or not.

    Returns
    -------
    X : list of shape (n_samples,) or array of shape (n_samples,)
        or array of shape (n_samples, n_conditions)
        List with filtered molecules or indicators.

    y : np.ndarray of shape (n_samples,)
        Array with labels for molecules. Reduced to the labels of molecules
        that passed the filter when molecules are returned, otherwise returned
        exactly as passed in.
    """
    indicators = self._get_filter_indicators(X, copy)

    if self.return_indicators or self.return_type != "mol":
        return indicators, y

    mols = [mol for mol, keep in zip(X, indicators) if keep]

    # ``y`` is documented as array-like, but plain lists and tuples cannot be
    # indexed with a boolean mask. Other inputs are left untouched so their
    # own indexing behavior is preserved.
    if isinstance(y, (list, tuple)):
        y = np.asarray(y)

    return mols, y[indicators]
193247
194248 def _get_filter_indicators (
195249 self , mols : Sequence [str | Mol ], copy : bool
196250 ) -> np .ndarray :
197251 self ._validate_params ()
198252 mols = deepcopy (mols ) if copy else mols
199- mols = ensure_mols (mols )
200253
201254 n_jobs = effective_n_jobs (self .n_jobs )
202255 if n_jobs == 1 :
@@ -207,23 +260,38 @@ def _get_filter_indicators(
207260 else :
208261 filter_indicators = self ._filter_mols_batch (mols )
209262 else :
263+ flatten_results = self .return_type != "condition_indicators"
264+
210265 filter_indicators = run_in_parallel (
211266 self ._filter_mols_batch ,
212267 data = mols ,
213268 n_jobs = n_jobs ,
214269 batch_size = self .batch_size ,
215- flatten_results = True ,
270+ flatten_results = flatten_results ,
216271 verbose = self .verbose ,
217272 )
218273
274+ if self .return_type == "condition_indicators" :
275+ filter_indicators = np .vstack (filter_indicators )
276+
219277 return filter_indicators
220278
def _filter_mols_batch(self, mols: Sequence[str | Mol]) -> np.ndarray:
    """
    Apply the filter to every molecule in a batch and collect the results.

    The inputs are first passed through ``ensure_mols``, then
    ``_apply_mol_filter`` is called once per molecule.

    Parameters
    ----------
    mols : {sequence of str or Mol}
        Batch of molecules to filter.

    Returns
    -------
    filter_indicators : np.ndarray
        Per-molecule results stacked row-wise with ``np.vstack`` when
        ``return_type`` is ``"condition_indicators"`` and ``return_indicators``
        is not set; otherwise the results converted to a boolean array.
    """
    mols = ensure_mols(mols)
    per_mol_results = [self._apply_mol_filter(mol) for mol in mols]

    # the deprecated ``return_indicators`` flag wins over ``return_type``
    per_condition = (
        not self.return_indicators and self.return_type == "condition_indicators"
    )
    if per_condition:
        return np.vstack(per_mol_results)

    return np.array(per_mol_results, dtype=bool)
224292
@abstractmethod
def _apply_mol_filter(self, mol: Mol) -> bool | np.ndarray:
    """
    Evaluate the filter for a single molecule. Subclasses must implement this.

    Parameters
    ----------
    mol : Mol
        Molecule to check.

    Returns
    -------
    result : bool or np.ndarray
        Filter result for the molecule. Results for a batch are later combined
        into a boolean array or stacked row-wise, so an array result is
        presumably one entry per filter condition.
    """
228296
229297 def _validate_params (self ) -> None :
0 commit comments