diff --git a/skfp/utils/validators.py b/skfp/utils/validators.py index af53444c..d01f77d3 100644 --- a/skfp/utils/validators.py +++ b/skfp/utils/validators.py @@ -2,24 +2,28 @@ from collections.abc import Callable, Sequence from typing import Any -from rdkit.Chem import Mol, MolFromSmiles, MolToSmiles +from rdkit.Chem import Mol, MolFromInchi, MolFromSmiles, MolToSmiles from rdkit.Chem.PropertyMol import PropertyMol def ensure_mols(X: Sequence[Any]) -> list[Mol]: """ Ensure that all input sequence elements are RDKit ``Mol`` objects. Requires - all input elements to be of the same type: string (SMILES strings) or ``Mol``. - In the case of SMILES strings, they are converted to RDKit ``Mol`` objects with + all input elements to be of the same type: string (SMILES or InChI strings) or ``Mol``. + In the case of SMILES or InChI strings, they are converted to RDKit ``Mol`` objects with default settings. """ if not all(isinstance(x, (Mol, PropertyMol, str)) for x in X): types = {type(x) for x in X} raise TypeError( - f"Passed values must be RDKit Mol objects or SMILES strings, got types: {types}" + f"Passed values must be RDKit Mol objects, SMILES or InChI strings, got types: {types}" ) - mols = [MolFromSmiles(x) if isinstance(x, str) else x for x in X] + if isinstance(X[0], str): + parser = MolFromInchi if X[0].startswith("InChI=") else MolFromSmiles + mols = [parser(x) for x in X] + else: + mols = list(X) if any(x is None for x in mols): idx = mols.index(None) diff --git a/tests/utils/validators.py b/tests/utils/validators.py index 983131d4..791238df 100644 --- a/tests/utils/validators.py +++ b/tests/utils/validators.py @@ -29,6 +29,27 @@ def test_ensure_mols_wrong_smiles(): assert "at index 1 as molecule" in str(exc_info) +def test_ensure_mols_valid_inchi(): + inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/CH4/h1H4"] + mols = ensure_mols(inchi_list) + assert all(m is not None for m in mols) + assert len(mols) == 2 + from rdkit.Chem import MolToSmiles + + smiles = [MolToSmiles(m) for m in mols] + assert "O" in smiles + assert "C" in smiles + + +def test_ensure_mols_invalid_inchi(): + inchi_list = ["InChI=1S/H2O/h1H2", "InChI=1S/invalid"] + with pytest.raises(TypeError) as exc_info: + ensure_mols(inchi_list) + + assert "Could not parse" in str(exc_info) + assert "at index 1 as molecule" in str(exc_info) + + def test_ensure_mols_in_fingerprint(): smiles_list = ["O", "O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3"] fp = AtomPairFingerprint()