Skip to content

Commit ac8cf63

Browse files
authored
Merge pull request #118 from ChEB-AI/feature/canonicalise-smiles
add smiles canonicalisation, update tokens.txt
2 parents 48725dd + e85a9c1 commit ac8cf63

File tree

2 files changed

+181
-1
lines changed

2 files changed

+181
-1
lines changed

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,3 +819,168 @@ p
819819
[16N]
820820
[17N]
821821
[14N]
822+
[Pb+2]
823+
[AlH4-]
824+
[BH4-]
825+
[Pt-2]
826+
[Cl+2]
827+
[I+3]
828+
[Br+2]
829+
[Cl+3]
830+
[Os-2]
831+
[Cr-2]
832+
[Hg-2]
833+
[PH]
834+
[Br+3]
835+
[I+2]
836+
[AsH2]
837+
[SH]
838+
[W-2]
839+
[Cd-2]
840+
[Ir-2]
841+
[Ru-2]
842+
[Rh-2]
843+
[Ag-2]
844+
[Be-2]
845+
[TeH2+]
846+
[13c]
847+
[13cH]
848+
[PH4]
849+
[AsH4]
850+
[As-2]
851+
[SbH3+]
852+
[SbH4]
853+
[BiH3]
854+
[BH3-]
855+
[GeH3]
856+
[GeH2]
857+
[SiH2-]
858+
[SiH2+]
859+
[SnH2]
860+
[SnH3]
861+
[SnH]
862+
[PbH]
863+
[PbH3]
864+
[Al-2]
865+
[B+2]
866+
[N+2]
867+
[SbH]
868+
[SbH2]
869+
[InH2]
870+
[GaH2]
871+
[TlH2]
872+
[Au+2]
873+
[sH+]
874+
[Hg+2]
875+
[Si-2]
876+
[Sn-2]
877+
[Pb-2]
878+
[AsH3]
879+
[Cr+2]
880+
[Ag+2]
881+
[V-2]
882+
[Ce-2]
883+
[13C@]
884+
[*+2]
885+
[He+2]
886+
[4He+2]
887+
[3He+2]
888+
[Eu+2]
889+
[Ge+2]
890+
[Os+2]
891+
[Y+2]
892+
[Gd+2]
893+
[La+2]
894+
[Se+2]
895+
[NH-2]
896+
[TeH2-]
897+
[AlH3-]
898+
[SbH3-]
899+
[AsH3-]
900+
[BiH3-]
901+
[PH3-]
902+
[CH2-2]
903+
[AsH4+]
904+
[AlH3+]
905+
[BiH3+]
906+
[FH+]
907+
[CH3+]
908+
[Te-2]
909+
[OH]
910+
[CH3]
911+
[18OH2]
912+
[OH3+]
913+
[OH4+2]
914+
[SH3]
915+
[SH3+]
916+
[SH3-]
917+
[SH4]
918+
[SeH2]
919+
[SeH-]
920+
[SeH3+]
921+
[SeH3-]
922+
[SeH3]
923+
[SeH+]
924+
[TeH2]
925+
[TeH-]
926+
[TeH3-]
927+
[TeH3+]
928+
[TeH+]
929+
[TeH3]
930+
[TeH4]
931+
[PoH2]
932+
[NH2]
933+
[NH+2]
934+
[PH5]
935+
[PH4+]
936+
[PH-2]
937+
[PH4-]
938+
[PH+2]
939+
[AsH2+]
940+
[AsH2-]
941+
[AsH+2]
942+
[AsH-2]
943+
[AsH5]
944+
[SbH3]
945+
[SbH4+]
946+
[SbH5]
947+
[BiH4+]
948+
[BiH5]
949+
[BiH4-]
950+
[BH2]
951+
[BH2+]
952+
[BH2-]
953+
[BH-2]
954+
[BH+2]
955+
[GeH4]
956+
[GeH3+]
957+
[GeH3-]
958+
[SiH3-]
959+
[SiH3+]
960+
[SiH+]
961+
[SiH4]
962+
[HeH+2]
963+
[HeH+]
964+
[AlH]
965+
[AlH+]
966+
[SnH4]
967+
[SnH3-]
968+
[SnH3+]
969+
[PbH4]
970+
[PbH3-]
971+
[PbH3+]
972+
[BeH4-2]
973+
[BeH]
974+
[BeH+]
975+
[BeH-]
976+
[BeH2]
977+
[AtH]
978+
[InH3]
979+
[GaH3]
980+
[TlH3]
981+
[IH3]
982+
[FeH6-4]
983+
[FH2+]
984+
[ClH2+]
985+
[BrH2+]
986+
[IH2+]

chebai/preprocessing/reader.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import deepsmiles
99
import selfies as sf
1010
from pysmiles.read_smiles import _tokenize
11+
from rdkit import Chem
1112
from transformers import RobertaTokenizerFast
1213

1314
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
@@ -176,21 +177,35 @@ class ChemDataReader(TokenIndexerReader):
176177

177178
COLLATOR = RaggedCollator
178179

180+
def __init__(self, canonicalize_smiles=True, *args, **kwargs) -> None:
    """Set up the reader.

    Args:
        canonicalize_smiles: When True, each SMILES string is canonicalized
            with RDKit in ``_read_data`` before tokenization.
        *args: Forwarded unchanged to the parent reader.
        **kwargs: Forwarded unchanged to the parent reader.
    """
    super().__init__(*args, **kwargs)
    self.canonicalize_smiles = canonicalize_smiles
    # Announce the configuration once at construction time so runs are
    # distinguishable in logs.
    print(f"Using SMILES canonicalization: {self.canonicalize_smiles}")
179185
@classmethod
def name(cls) -> str:
    """Return the identifier of this data reader: ``"smiles_token"``."""
    return "smiles_token"
183189

184190
def _read_data(self, raw_data: str) -> List[int]:
    """
    Reads and tokenizes raw SMILES data into a list of token indices.
    Canonicalizes the SMILES string using RDKit when
    ``self.canonicalize_smiles`` is True.

    Args:
        raw_data (str): The raw SMILES string to be tokenized.

    Returns:
        List[int]: A list of integers representing the indices of the SMILES tokens.
    """
    if self.canonicalize_smiles:
        try:
            mol = Chem.MolFromSmiles(raw_data.strip())
            if mol is not None:
                raw_data = Chem.MolToSmiles(mol, canonical=True)
            else:
                # MolFromSmiles reports parse failure by returning None, not
                # by raising, so without this branch an invalid SMILES would
                # silently be tokenized un-canonicalized. Best-effort: warn
                # and fall through to tokenizing the original string.
                print(f"RDKit could not parse {raw_data}; tokenizing as-is")
        except Exception as e:
            # Defensive catch-all for unexpected RDKit errors; keep the
            # original string rather than aborting the whole read.
            print(f"RDKit failed to process {raw_data}")
            print(f"\t{e}")
    return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
195210

196211

0 commit comments

Comments
 (0)