|
2 | 2 | from collections.abc import Callable, Sequence |
3 | 3 | from typing import Any |
4 | 4 |
|
5 | | -from rdkit.Chem import Mol, MolFromSmiles, MolToSmiles |
| 5 | +from rdkit.Chem import Mol, MolFromInchi, MolFromSmiles, MolToSmiles |
6 | 6 | from rdkit.Chem.PropertyMol import PropertyMol |
7 | 7 |
|
8 | 8 |
|
9 | 9 | def ensure_mols(X: Sequence[Any]) -> list[Mol]: |
10 | 10 | """ |
11 | 11 | Ensure that all input sequence elements are RDKit ``Mol`` objects. Requires |
12 | | - all input elements to be of the same type: string (SMILES strings) or ``Mol``. |
13 | | - In the case of SMILES strings, they are converted to RDKit ``Mol`` objects with |
| 12 | + all input elements to be of the same type: string (SMILES or InChI strings) or ``Mol``. |
| 13 | + In the case of SMILES or InChI strings, they are converted to RDKit ``Mol`` objects with |
14 | 14 | default settings. |
15 | 15 | """ |
16 | 16 | if not all(isinstance(x, (Mol, PropertyMol, str)) for x in X): |
17 | 17 | types = {type(x) for x in X} |
18 | 18 | raise TypeError( |
19 | | - f"Passed values must be RDKit Mol objects or SMILES strings, got types: {types}" |
| 19 | + f"Passed values must be RDKit Mol objects, SMILES or InChI strings, got types: {types}" |
20 | 20 | ) |
21 | 21 |
|
22 | | - mols = [MolFromSmiles(x) if isinstance(x, str) else x for x in X] |
| 22 | + if isinstance(X[0], str): |
| 23 | + parser = MolFromInchi if X[0].startswith("InChI=") else MolFromSmiles |
| 24 | + mols = [parser(x) for x in X] |
| 25 | + else: |
| 26 | + mols = list(X) |
23 | 27 |
|
24 | 28 | if any(x is None for x in mols): |
25 | 29 | idx = mols.index(None) |
|
0 commit comments