Skip to content

Commit e85a9c1

Browse files
committed
add canonicalize flag
1 parent 04fd197 commit e85a9c1

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

chebai/preprocessing/reader.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,11 @@ class ChemDataReader(TokenIndexerReader):
177 177

178 178
COLLATOR = RaggedCollator
179 179

180+
def __init__(self, canonicalize_smiles=True, *args, **kwargs) -> None:
181+
super().__init__(*args, **kwargs)
182+
self.canonicalize_smiles = canonicalize_smiles
183+
print(f"Using SMILES canonicalization: {self.canonicalize_smiles}")
184+
180 185
@classmethod
181 186
def name(cls) -> str:
182 187
"""Returns the name of the data reader."""
@@ -192,13 +197,14 @@ def _read_data(self, raw_data: str) -> List[int]:
192 197
Returns:
193 198
List[int]: A list of integers representing the indices of the SMILES tokens.
194 199
"""
195-
try:
196-
mol = Chem.MolFromSmiles(raw_data.strip())
197-
if mol is not None:
198-
raw_data = Chem.MolToSmiles(mol, canonical=True)
199-
except Exception as e:
200-
print(f"RDKit failed to process {raw_data}")
201-
print(f"\t{e}")
200+
if self.canonicalize_smiles:
201+
try:
202+
mol = Chem.MolFromSmiles(raw_data.strip())
203+
if mol is not None:
204+
raw_data = Chem.MolToSmiles(mol, canonical=True)
205+
except Exception as e:
206+
print(f"RDKit failed to process {raw_data}")
207+
print(f"\t{e}")
202 208

203 209
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
204 210

0 commit comments

Comments
 (0)