Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions skfp/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,28 @@
from collections.abc import Callable, Sequence
from typing import Any

from rdkit.Chem import Mol, MolFromSmiles, MolToSmiles
from rdkit.Chem import Mol, MolFromInchi, MolFromSmiles, MolToSmiles
from rdkit.Chem.PropertyMol import PropertyMol


def ensure_mols(X: Sequence[Any]) -> list[Mol]:
"""
Ensure that all input sequence elements are RDKit ``Mol`` objects. Requires
all input elements to be of the same type: string (SMILES strings) or ``Mol``.
In the case of SMILES strings, they are converted to RDKit ``Mol`` objects with
all input elements to be of the same type: string (SMILES or InChI strings) or ``Mol``.
In the case of SMILES or InChI strings, they are converted to RDKit ``Mol`` objects with
default settings.
"""
if not all(isinstance(x, (Mol, PropertyMol, str)) for x in X):
types = {type(x) for x in X}
raise TypeError(
f"Passed values must be RDKit Mol objects or SMILES strings, got types: {types}"
f"Passed values must be RDKit Mol objects, SMILES or InChI strings, got types: {types}"
)

mols = [MolFromSmiles(x) if isinstance(x, str) else x for x in X]
if isinstance(X[0], str):
parser = MolFromInchi if X[0].startswith("InChI=") else MolFromSmiles
mols = [parser(x) for x in X]
else:
mols = list(X)

if any(x is None for x in mols):
idx = mols.index(None)
Expand Down
21 changes: 21 additions & 0 deletions tests/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ def test_ensure_mols_wrong_smiles():
assert "at index 1 as molecule" in str(exc_info)


def test_ensure_mols_valid_inchi():
inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/CH4/h1H4"]
mols = ensure_mols(inchi_list)
assert all(m is not None for m in mols)
assert len(mols) == 2
from rdkit.Chem import MolToSmiles

smiles = [MolToSmiles(m) for m in mols]
assert "O" in smiles
assert "C" in smiles


def test_ensure_mols_invalid_inchi():
inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/invalid"]
with pytest.raises(TypeError) as exc_info:
ensure_mols(inchi_list)

assert "Could not parse" in str(exc_info)
assert "at index 1 as molecule" in str(exc_info)


def test_ensure_mols_in_fingerprint():
smiles_list = ["O", "O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3"]
fp = AtomPairFingerprint()
Expand Down