The original code from the notebook:

```python
import torch
from torch.utils.data import DataLoader


def predict(self, smiles_list):
    unimol_input = self.preprocess_data(smiles_list)
    dataset = MolDataset(unimol_input)
    dataloader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            collate_fn=self.model.batch_collate_fn,
                            )
    results = {}
    for batch in dataloader:
        net_input, _ = self.decorate_torch_batch(batch)
        with torch.no_grad():
            predictions = self.model(**net_input)
        # BUG: zip() pairs the *whole* smiles_list with only this batch's predictions.
        for smiles, energy in zip(smiles_list, predictions):
            results[smiles] = energy.item()
    return results
```

The issue is that `zip` is called with `smiles_list`, which contains every SMILES, and `predictions`, which holds only the current batch (at most `batch_size` items, 32 by default). The two iterables have different lengths, so `zip` stops at the shorter one and, on every batch, writes that batch's energies onto the first SMILES of the list. In the end, `results` contains only the first `batch_size` SMILES, and their energies come from the last batch(es) rather than from those molecules. The code works correctly only if `smiles_list` has no more than `batch_size` entries.
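To see the failure mode concretely, here is a minimal pure-Python sketch that imitates the loop above without PyTorch or Uni-Mol. `batched` stands in for `DataLoader(..., shuffle=False)` and `fake_model` is a hypothetical stand-in for the network that returns the string length as the "energy"; neither is part of the notebook.

```python
from typing import Callable, Dict, List, Sequence

Model = Callable[[List[str]], List[float]]


def batched(items: Sequence[str], batch_size: int) -> List[List[str]]:
    """Split items into consecutive chunks, like a DataLoader with shuffle=False."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def fake_model(batch: List[str]) -> List[float]:
    """Hypothetical model: the 'energy' of a SMILES is just its length."""
    return [float(len(s)) for s in batch]


def predict_buggy(smiles_list: Sequence[str], model: Model, batch_size: int) -> Dict[str, float]:
    results: Dict[str, float] = {}
    for batch in batched(smiles_list, batch_size):
        predictions = model(batch)
        # Same bug as the notebook: zip() pairs the *whole* list with this
        # batch's predictions, so it always overwrites the first keys.
        for smiles, energy in zip(smiles_list, predictions):
            results[smiles] = energy
    return results


smiles = ["C", "CC", "CCC", "CCCC", "CCCCC"]
print(predict_buggy(smiles, fake_model, batch_size=2))
# {'C': 5.0, 'CC': 4.0}
```

Only two of the five molecules survive: `"C"` carries the energy of `"CCCCC"` (last batch) and `"CC"` carries the energy of `"CCCC"` (second-to-last batch). With `batch_size=5` or larger the result would be correct, matching the observation that the bug only appears once the list exceeds one batch.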
A possible fix:

```python
import torch
from torch.utils.data import DataLoader


def predict(self, smiles_list):
    unimol_input = self.preprocess_data(smiles_list)
    dataset = MolDataset(unimol_input)
    dataloader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,  # keep batch order aligned with smiles_list
                            collate_fn=self.model.batch_collate_fn,
                            )
    all_energies = []
    for batch in dataloader:
        net_input, _ = self.decorate_torch_batch(batch)
        with torch.no_grad():
            predictions = self.model(**net_input)
        # Accumulate this batch's energies in dataloader order.
        all_energies.extend(energy.item() for energy in predictions)
    # Pair every SMILES with its own energy only after all batches have run.
    results = dict(zip(smiles_list, all_energies))
    return results
```
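The same toy setup can check the fix: collect energies across all batches first, then `zip` once. This sketch also adds an explicit length check (not part of the proposed fix) so that a mismatch between SMILES and predictions raises an error instead of being silently truncated by `zip`; on Python 3.10+, `zip(..., strict=True)` gives the same guarantee. `batched` and `fake_model` are again hypothetical stand-ins for the `DataLoader` and the model.

```python
from typing import Callable, Dict, List, Sequence

Model = Callable[[List[str]], List[float]]


def batched(items: Sequence[str], batch_size: int) -> List[List[str]]:
    """Split items into consecutive chunks, like a DataLoader with shuffle=False."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def fake_model(batch: List[str]) -> List[float]:
    """Hypothetical model: the 'energy' of a SMILES is just its length."""
    return [float(len(s)) for s in batch]


def predict_fixed(smiles_list: Sequence[str], model: Model, batch_size: int) -> Dict[str, float]:
    all_energies: List[float] = []
    for batch in batched(smiles_list, batch_size):
        # Keep energies in batch order, like all_energies.extend(...) in the fix.
        all_energies.extend(model(batch))
    # Extra guard: fail loudly instead of letting zip() drop unmatched items.
    if len(all_energies) != len(smiles_list):
        raise ValueError(
            f"got {len(all_energies)} predictions for {len(smiles_list)} SMILES"
        )
    # Pair each SMILES with its own energy only after every batch has run.
    return dict(zip(smiles_list, all_energies))


smiles = ["C", "CC", "CCC", "CCCC", "CCCCC"]
print(predict_fixed(smiles, fake_model, batch_size=2))
# {'C': 1.0, 'CC': 2.0, 'CCC': 3.0, 'CCCC': 4.0, 'CCCCC': 5.0}
```

Every molecule now keeps its own energy regardless of `batch_size`, which is a cheap regression test to run against the real `predict` with a list longer than one batch.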