Skip to content

Commit 8390014

Browse files
InChI string support in ensure_mols (#522)
1 parent 9527e92 commit 8390014

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

skfp/utils/validators.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,24 +2,28 @@
22
from collections.abc import Callable, Sequence
33
from typing import Any
44

5-
from rdkit.Chem import Mol, MolFromSmiles, MolToSmiles
5+
from rdkit.Chem import Mol, MolFromInchi, MolFromSmiles, MolToSmiles
66
from rdkit.Chem.PropertyMol import PropertyMol
77

88

99
def ensure_mols(X: Sequence[Any]) -> list[Mol]:
1010
"""
1111
Ensure that all input sequence elements are RDKit ``Mol`` objects. Requires
12-
all input elements to be of the same type: string (SMILES strings) or ``Mol``.
13-
In the case of SMILES strings, they are converted to RDKit ``Mol`` objects with
12+
all input elements to be of the same type: string (SMILES or InChI strings) or ``Mol``.
13+
In the case of SMILES or InChI strings, they are converted to RDKit ``Mol`` objects with
1414
default settings.
1515
"""
1616
if not all(isinstance(x, (Mol, PropertyMol, str)) for x in X):
1717
types = {type(x) for x in X}
1818
raise TypeError(
19-
f"Passed values must be RDKit Mol objects or SMILES strings, got types: {types}"
19+
f"Passed values must be RDKit Mol objects, SMILES or InChI strings, got types: {types}"
2020
)
2121

22-
mols = [MolFromSmiles(x) if isinstance(x, str) else x for x in X]
22+
if isinstance(X[0], str):
23+
parser = MolFromInchi if X[0].startswith("InChI=") else MolFromSmiles
24+
mols = [parser(x) for x in X]
25+
else:
26+
mols = list(X)
2327

2428
if any(x is None for x in mols):
2529
idx = mols.index(None)

tests/utils/validators.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,27 @@ def test_ensure_mols_wrong_smiles():
2929
assert "at index 1 as molecule" in str(exc_info)
3030

3131

32+
def test_ensure_mols_valid_inchi():
33+
inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/CH4/h1H4"]
34+
mols = ensure_mols(inchi_list)
35+
assert all(m is not None for m in mols)
36+
assert len(mols) == 2
37+
from rdkit.Chem import MolToSmiles
38+
39+
smiles = [MolToSmiles(m) for m in mols]
40+
assert "O" in smiles
41+
assert "C" in smiles
42+
43+
44+
def test_ensure_mols_invalid_inchi():
45+
inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/invalid"]
46+
with pytest.raises(TypeError) as exc_info:
47+
ensure_mols(inchi_list)
48+
49+
assert "Could not parse" in str(exc_info)
50+
assert "at index 1 as molecule" in str(exc_info)
51+
52+
3253
def test_ensure_mols_in_fingerprint():
3354
smiles_list = ["O", "O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3"]
3455
fp = AtomPairFingerprint()

0 commit comments

Comments
 (0)