Skip to content

Commit d6a6afc

Browse files
authored
Add multithreading to MolFromSDFTransformer [Issue 467] (#526)
1 parent 0eb7623 commit d6a6afc

File tree

3 files changed

+142
-7
lines changed

3 files changed

+142
-7
lines changed

skfp/preprocessing/input_output/sdf.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import os.path
22
import warnings
33
from collections.abc import Sequence
4+
from numbers import Integral
45

6+
from joblib import effective_n_jobs
57
from rdkit.Chem import Mol, SDMolSupplier, SDWriter
68
from rdkit.Chem.PropertyMol import PropertyMol
79

810
from skfp.bases import BasePreprocessor
911
from skfp.utils import require_mols
12+
from skfp.utils.functions import _get_rdkit_version
13+
14+
_MIN_MULTITHREADED_SDF_VERSION = (2025, 9, 1)
1015

1116

1217
class MolFromSDFTransformer(BasePreprocessor):
@@ -29,6 +34,11 @@ class MolFromSDFTransformer(BasePreprocessor):
2934
Remove explicit hydrogens from the molecule where possible, using RDKit
3035
implicit hydrogens instead.
3136
37+
n_jobs : int, default=None
38+
The number of jobs to use when reading molecules from an SDF file path.
39+
If ``n_jobs > 1`` and the installed RDKit version is at least ``2025.09.1``,
40+
the file is read in parallel. Raw SDF text input is always processed sequentially.
41+
3242
References
3343
----------
3444
.. [1] `RDKit SDMolSupplier documentation
@@ -50,14 +60,16 @@ class MolFromSDFTransformer(BasePreprocessor):
5060
_parameter_constraints: dict = {
5161
"sanitize": ["boolean"],
5262
"remove_hydrogens": ["boolean"],
63+
"n_jobs": [Integral, None],
5364
}
5465

5566
def __init__(
5667
self,
5768
sanitize: bool = True,
5869
remove_hydrogens: bool = True,
70+
n_jobs: int | None = None,
5971
):
60-
super().__init__()
72+
super().__init__(n_jobs=n_jobs)
6173
self.sanitize = sanitize
6274
self.remove_hydrogens = remove_hydrogens
6375

@@ -84,12 +96,9 @@ def transform(self, X: str, copy: bool = False) -> list[Mol]: # type: ignore[ov
8496
if not os.path.exists(X):
8597
raise FileNotFoundError(f"SDF file at path '{X}' not found")
8698

87-
with open(X) as file:
88-
X = file.read()
89-
90-
supplier = SDMolSupplier()
91-
supplier.SetData(X, sanitize=self.sanitize, removeHs=self.remove_hydrogens)
92-
mols = list(supplier)
99+
mols = self._read_sdf_file(X)
100+
else:
101+
mols = self._read_sdf_text(X)
93102

94103
if not mols:
95104
warnings.warn("No molecules detected in provided SDF file")
@@ -99,6 +108,61 @@ def transform(self, X: str, copy: bool = False) -> list[Mol]: # type: ignore[ov
99108
def _transform_batch(self, X):
100109
pass # unused
101110

111+
def _read_sdf_file(self, filepath: str) -> list[Mol]:
    """
    Read every molecule from the SDF file located at ``filepath``.

    When the effective number of jobs is greater than one and the installed
    RDKit is not older than ``_MIN_MULTITHREADED_SDF_VERSION``, reading is
    delegated to ``_read_sdf_file_parallel``. With an older RDKit a warning
    is emitted and the file is read sequentially instead.

    Parameters
    ----------
    filepath : str
        Path of the SDF file to read.

    Returns
    -------
    mols : list[Mol]
        Contents of the supplier, in the order it yields them.
    """
    workers = effective_n_jobs(self.n_jobs)

    if workers > 1:
        rdkit_version = _get_rdkit_version()
        if rdkit_version >= _MIN_MULTITHREADED_SDF_VERSION:
            return self._read_sdf_file_parallel(filepath, workers)

        # parallel reading was requested, but this RDKit cannot provide it
        warnings.warn(
            "Parallel SDF reading requires RDKit >= 2025.09.1. "
            f"Installed version is {'.'.join(map(str, rdkit_version))}. "
            "Falling back to sequential loading."
        )

    sequential_supplier = SDMolSupplier(
        filepath,
        sanitize=self.sanitize,
        removeHs=self.remove_hydrogens,
    )
    return list(sequential_supplier)
132+
133+
def _read_sdf_file_parallel(self, filepath: str, n_jobs: int) -> list[Mol]:
    """
    Read molecules from an SDF file with RDKit's multithreaded supplier.

    Each molecule is paired with the record ID reported by the supplier
    right after it was yielded, and the result is sorted by that ID before
    being returned.

    Parameters
    ----------
    filepath : str
        Path of the SDF file to read.

    n_jobs : int
        Value passed to the supplier as ``numWriterThreads``.

    Returns
    -------
    mols : list[Mol]
        Molecules ordered by ascending record ID, without ``None`` entries.
    """
    # imported lazily, since callers check the RDKit version before calling this
    from rdkit.Chem import MultithreadedSDMolSupplier

    indexed_mols = []
    with MultithreadedSDMolSupplier(
        filepath,
        sanitize=self.sanitize,
        removeHs=self.remove_hydrogens,
        numWriterThreads=n_jobs,
    ) as supplier:
        for mol in supplier:
            # multithreaded supplier may yield None duplicates
            # NOTE(review): this also drops any other None the supplier yields,
            # while the sequential path returns list(SDMolSupplier(...))
            # unfiltered - confirm this difference is intended
            if mol is None:
                continue
            # must be queried before advancing to the next molecule
            record_id = supplier.GetLastRecordId()
            indexed_mols.append((record_id, mol))

    ordered = sorted(indexed_mols, key=lambda pair: pair[0])
    return [mol for _, mol in ordered]
150+
151+
def _read_sdf_text(self, sdf_text: str) -> list[Mol]:
    """
    Parse molecules from raw SDF text, always sequentially.

    If the effective number of jobs is greater than one, a warning is
    emitted, because parallel reading is only done for file paths.

    Parameters
    ----------
    sdf_text : str
        Contents of an SDF file.

    Returns
    -------
    mols : list[Mol]
        Contents of the supplier, in the order it yields them.
    """
    parallel_requested = effective_n_jobs(self.n_jobs) > 1
    if parallel_requested:
        warnings.warn(
            "Parallel SDF reading requires a file path. Falling back to sequential "
            "loading for raw SDF text input."
        )

    text_supplier = SDMolSupplier()
    text_supplier.SetData(
        sdf_text,
        sanitize=self.sanitize,
        removeHs=self.remove_hydrogens,
    )
    return list(text_supplier)
165+
102166

103167
class MolToSDFTransformer(BasePreprocessor):
104168
"""

skfp/utils/functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from importlib.metadata import version
33

44
import pandas as pd
5+
from rdkit import rdBase
56

67

78
def get_data_from_indices(data: Sequence, indices: Sequence[int]) -> list:
@@ -20,3 +21,13 @@ def _get_sklearn_version():
2021
sklearn_ver = version("scikit-learn") # e.g. 1.6.0
2122
sklearn_ver = ".".join(sklearn_ver.split(".")[:2]) # e.g. 1.6
2223
return float(sklearn_ver)
24+
25+
26+
def _get_rdkit_version() -> tuple[int, int, int]:
    """
    Return the installed RDKit version as a tuple of three integers.

    Unlike scikit-learn version helper, which uses float (broken for
    minor >= 10, e.g. 2025.1 == 2025.10), a tuple is returned for correct
    ordering in comparisons.

    Only the leading digits of each of the first three dot-separated
    components are used, so a suffixed component such as ``"1pre"`` is read
    as ``1`` instead of failing the integer conversion.

    Returns
    -------
    version : tuple[int, int, int]
        First three numeric components of ``rdBase.rdkitVersion``,
        e.g. ``(2025, 9, 3)`` for ``"2025.09.3"``.

    Raises
    ------
    RuntimeError
        If the version string has fewer than three components, or if any of
        the first three does not start with a digit.
    """
    # local import keeps this block self-contained
    import re

    rdkit_ver = rdBase.rdkitVersion  # e.g. "2025.09.3"
    parts = rdkit_ver.split(".")
    if len(parts) < 3:
        raise RuntimeError(f"Cannot parse RDKit version: {rdkit_ver}")

    numbers = []
    for part in parts[:3]:
        leading_digits = re.match(r"\d+", part)
        if leading_digits is None:
            raise RuntimeError(f"Cannot parse RDKit version: {rdkit_ver}")
        numbers.append(int(leading_digits.group()))

    return numbers[0], numbers[1], numbers[2]

tests/preprocessing/input_output/sdf.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from rdkit.Chem import Mol
66

77
from skfp.preprocessing import MolFromSDFTransformer, MolToSDFTransformer
8+
from skfp.preprocessing.input_output import sdf as sdf_module
89

910

1011
@pytest.fixture
@@ -45,6 +46,65 @@ def test_mol_to_and_from_sdf(mols_list, sdf_out_file_path):
4546
assert all(isinstance(x, Mol) for x in mols_list_2)
4647

4748

49+
def test_mol_from_sdf_parallel_from_file(sdf_in_file_path):
    """Reading a file path with ``n_jobs=2`` returns exactly one ``Mol``."""
    transformer = MolFromSDFTransformer(n_jobs=2)
    loaded = transformer.transform(sdf_in_file_path)

    assert_equal(len(loaded), 1)
    for mol in loaded:
        assert isinstance(mol, Mol)
55+
56+
57+
def test_mol_from_sdf_parallel_warns_for_raw_text(sdf_in_file_path):
    """Raw SDF text with ``n_jobs=2`` triggers a warning and is still parsed."""
    with open(sdf_in_file_path) as sdf_file:
        raw_text = sdf_file.read()

    transformer = MolFromSDFTransformer(n_jobs=2)
    expected_warning = pytest.warns(
        UserWarning,
        match="Parallel SDF reading requires a file path",
    )
    with expected_warning:
        loaded = transformer.transform(raw_text)

    assert_equal(len(loaded), 1)
    for mol in loaded:
        assert isinstance(mol, Mol)
70+
71+
72+
def test_mol_from_sdf_parallel_preserves_order(mols_list, tmp_path):
    """Parallel reading returns molecules in the same order as sequential reading."""

    def _named_copy(idx, source_mol):
        # add a unique name to each copy, so that the order can be verified
        duplicate = Mol(source_mol)
        duplicate.SetProp("_Name", f"mol_{idx}")
        return duplicate

    named_mols = [_named_copy(idx, mol) for idx, mol in enumerate(mols_list[:5])]

    sdf_path = str(tmp_path / "ordered_mols.sdf")
    MolToSDFTransformer(sdf_path).transform(named_mols)

    # read the same file both ways and compare the resulting name sequences
    names_sequential = [
        mol.GetProp("_Name") for mol in MolFromSDFTransformer().transform(sdf_path)
    ]
    names_parallel = [
        mol.GetProp("_Name")
        for mol in MolFromSDFTransformer(n_jobs=2).transform(sdf_path)
    ]

    assert names_parallel == names_sequential
92+
93+
94+
def test_mol_from_sdf_parallel_falls_back_for_older_rdkit(monkeypatch):
    """With a patched older RDKit version, the sequential supplier is used with a warning."""
    marker = object()

    def fake_rdkit_version():
        # older than the minimum version required for parallel reading
        return (2025, 3, 0)

    def fake_supplier(*_args, **_kwargs):
        # stands in for SDMolSupplier, so that no real file is needed
        return [marker]

    monkeypatch.setattr(sdf_module, "_get_rdkit_version", fake_rdkit_version)
    monkeypatch.setattr(sdf_module, "SDMolSupplier", fake_supplier)

    transformer = MolFromSDFTransformer(n_jobs=2)
    with pytest.warns(UserWarning, match="requires RDKit >= 2025.09.1"):
        result = transformer._read_sdf_file("ignored.sdf")

    assert result == [marker]
106+
107+
48108
def test_error_nonexistent_sdf_file():
49109
mol_from_sdf = MolFromSDFTransformer()
50110
with pytest.raises(FileNotFoundError):

0 commit comments

Comments
 (0)