
Commit a11b0b7

restructure error handling and update cache indexing for gnn

1 parent: e8e4ec3
File tree

chebifier/prediction_models/gnn_predictor.py
chebifier/prediction_models/nn_predictor.py

2 files changed: 14 additions, 13 deletions

chebifier/prediction_models/gnn_predictor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -60,7 +60,7 @@ def read_smiles(self, smiles):
             if isinstance(prop.encoder, IndexEncoder):
                 if str(value) in prop.encoder.cache:
                     index = (
-                        prop.encoder.cache.index(str(value)) + prop.encoder.offset
+                        prop.encoder.cache[str(value)] + prop.encoder.offset
                     )
                 else:
                     index = 0
```
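The change above swaps a linear list search (`cache.index(...)`) for a direct subscript lookup (`cache[...]`), which only works if `prop.encoder.cache` maps values to positions rather than merely storing them in order. Below is a minimal sketch of the difference, assuming a dict-backed cache; the variable names and example values are illustrative, not taken from chebifier:

```python
# Illustrative sketch only -- not the chebifier IndexEncoder implementation.
value = "O"
offset = 10

# Before: the cache behaves like a list, so every hit requires a linear scan.
cache_as_list = ["C", "N", "O"]
if str(value) in cache_as_list:                        # O(n) membership test
    index = cache_as_list.index(str(value)) + offset   # O(n) search again

# After: the cache behaves like a dict mapping value -> position.
cache_as_dict = {"C": 0, "N": 1, "O": 2}
if str(value) in cache_as_dict:                        # O(1) membership test
    index = cache_as_dict[str(value)] + offset         # O(1) lookup
```

With a dict, both the membership test and the lookup are constant time, whereas `list.index` rescans the cache on every hit.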

chebifier/prediction_models/nn_predictor.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -57,25 +57,26 @@ def predict_smiles_list(self, smiles_list) -> list:
         could_not_parse = []
         index_map = dict()
         for i, smiles in enumerate(smiles_list):
+            if not smiles:
+                print(f"Model {self.model_name} received a missing SMILES string at position {i}.")
+                could_not_parse.append(i)
+                continue
             try:
-                # Try to parse the smiles string
-                if not smiles:
-                    raise ValueError()
                 d = self.read_smiles(smiles)
+
                 # This is just for sanity checks
                 rdmol = Chem.MolFromSmiles(smiles, sanitize=False)
-            except Exception as e:
-                # Note if it fails
-                could_not_parse.append(i)
-                print(f"Failing to parse {smiles} due to {e}")
-            else:
                 if rdmol is None:
+                    print(f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}")
                     could_not_parse.append(i)
-                else:
-                    index_map[i] = len(token_dicts)
-                    token_dicts.append(d)
+                    continue
+            except Exception as e:
+                could_not_parse.append(i)
+                print(f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}")
+            index_map[i] = len(token_dicts)
+            token_dicts.append(d)
         results = []
-        if token_dicts:
+        if len(token_dicts) > 0:
             for batch in tqdm.tqdm(
                 self.batchify(token_dicts),
                 desc=f"{self.model_name}",
```
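The restructuring replaces the old try/except/else nesting with guard clauses: empty SMILES strings and strings RDKit cannot read are reported and skipped with an early `continue`, and only inputs that survive both checks are tokenised and recorded in `token_dicts`/`index_map`. Below is a standalone sketch of that control flow under those assumptions, with a placeholder `tokenize` standing in for `self.read_smiles`, the model plumbing omitted, and an explicit `continue` in the except branch so a failed parse is never appended; everything except `Chem.MolFromSmiles` is hypothetical:

```python
# Illustrative sketch of the restructured parsing loop -- not the chebifier code itself.
from rdkit import Chem


def tokenize(smiles: str) -> dict:
    """Placeholder for the model-specific read_smiles / tokenisation step."""
    return {"smiles": smiles}


def collect_parsable(smiles_list):
    token_dicts, index_map, could_not_parse = [], {}, []
    for i, smiles in enumerate(smiles_list):
        if not smiles:                        # guard 1: missing or empty input
            could_not_parse.append(i)
            continue
        try:
            d = tokenize(smiles)
            # Sanity check: RDKit must at least be able to read the string.
            if Chem.MolFromSmiles(smiles, sanitize=False) is None:
                could_not_parse.append(i)     # guard 2: RDKit cannot read it
                continue
        except Exception:
            could_not_parse.append(i)         # guard 3: tokenisation failed
            continue
        index_map[i] = len(token_dicts)       # map input position -> batch position
        token_dicts.append(d)
    return token_dicts, index_map, could_not_parse


# Example: the empty string and the unreadable string end up in could_not_parse.
tokens, idx_map, bad = collect_parsable(["CCO", "", "not-a-smiles"])
```

Guard clauses keep the happy path at a single indentation level, which is the main readability gain of this kind of restructure.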
