
Bug in FreeEnergyPredictor.predict #11

@DrrDom

The original code from the notebook

  def predict(self, smiles_list):
      unimol_input = self.preprocess_data(smiles_list)
      dataset = MolDataset(unimol_input)
      dataloader = DataLoader(dataset, 
                              batch_size=self.batch_size, 
                              shuffle=False,
                              collate_fn=self.model.batch_collate_fn,
                              )

      results = {}
      for batch in dataloader:
          net_input, _ = self.decorate_torch_batch(batch)
          with torch.no_grad():
              predictions = self.model(**net_input)
              for smiles, energy in zip(smiles_list, predictions):
                  results[smiles] = energy.item()
      return results

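The issue is that zip is called on smiles_list, which is the whole list of SMILES, and on predictions, which holds only one batch (batch_size items, 32 by default). The two iterables have different lengths, so zip stops at the shorter one, and on every iteration it pairs the first SMILES of the list with the current batch's predictions. In the end, results contains only the first 32 SMILES, and their energies are overwritten batch after batch, so they hold predictions for other molecules from the last batch(es). The code works correctly only if smiles_list is no longer than batch_size.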
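To see the effect without torch or Uni-Mol, the loop can be reproduced in plain Python. This is a hypothetical sketch: fake_energy() and batches() are stand-ins I made up for the model and for a DataLoader with shuffle=False, and the "energy" is just the SMILES length so you can tell which molecule a value was computed for.

```python
# Hypothetical reproduction of the bug; no torch or Uni-Mol involved.

def fake_energy(smiles):
    # Stand-in for the model: the value identifies the molecule it came from.
    return float(len(smiles))


def batches(items, batch_size):
    # Yield consecutive chunks in order, like DataLoader(shuffle=False).
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def buggy_predict(smiles_list, batch_size):
    results = {}
    for batch in batches(smiles_list, batch_size):
        predictions = [fake_energy(s) for s in batch]
        # Same mistake as in the notebook: the *whole* list is zipped with
        # one batch, so only the first len(predictions) SMILES get written,
        # and they are overwritten on every iteration.
        for smiles, energy in zip(smiles_list, predictions):
            results[smiles] = energy
    return results


print(buggy_predict(["C", "CC", "CCC", "CCCC", "CCCCC"], 2))
```

With five SMILES and batch_size=2 this prints {'C': 5.0, 'CC': 4.0}: only the first two SMILES survive, "C" carries the value computed for "CCCCC" (last, partial batch) and "CC" the value computed for "CCCC" (the batch before it).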

A possible fix

  def predict(self, smiles_list):
      unimol_input = self.preprocess_data(smiles_list)
      dataset = MolDataset(unimol_input)
      dataloader = DataLoader(dataset,
                              batch_size=self.batch_size,
                              shuffle=False,  # keep batch order aligned with smiles_list
                              collate_fn=self.model.batch_collate_fn,
                              )

      # Collect the predictions of every batch first, in DataLoader order...
      all_energies = []
      for batch in dataloader:
          net_input, _ = self.decorate_torch_batch(batch)
          with torch.no_grad():
              predictions = self.model(**net_input)
          all_energies.extend(energy.item() for energy in predictions)

      # ...then pair them with the SMILES once, so each SMILES gets its own energy.
      results = dict(zip(smiles_list, all_energies))
      return results
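The same hypothetical stand-ins (fake_energy() and batches(), not part of the repository) can illustrate the fixed pattern: accumulate every batch's predictions, then zip once. The length check is an extra safeguard I'm adding, not part of the fix above. It guards against the case where preprocess_data does not return exactly one input per SMILES (I haven't checked whether it can skip or expand entries); without it, zip would silently truncate instead of failing.

```python
# Hypothetical sketch of the fixed pattern; no torch or Uni-Mol involved.

def fake_energy(smiles):
    # Stand-in for the model: the value identifies the molecule it came from.
    return float(len(smiles))


def batches(items, batch_size):
    # Yield consecutive chunks in order, like DataLoader(shuffle=False).
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def fixed_predict(smiles_list, batch_size):
    all_energies = []
    for batch in batches(smiles_list, batch_size):
        # Accumulate predictions across batches instead of zipping per batch.
        all_energies.extend(fake_energy(s) for s in batch)
    # Optional safeguard: fail loudly if inputs and outputs are misaligned.
    if len(all_energies) != len(smiles_list):
        raise ValueError(
            f"got {len(all_energies)} predictions for {len(smiles_list)} SMILES"
        )
    return dict(zip(smiles_list, all_energies))


print(fixed_predict(["C", "CC", "CCC", "CCCC", "CCCCC"], 2))
```

This prints {'C': 1.0, 'CC': 2.0, 'CCC': 3.0, 'CCCC': 4.0, 'CCCCC': 5.0}, i.e. every SMILES is paired with its own value regardless of batch_size. Note that building the result with dict(zip(...)) still collapses duplicate SMILES in the input into a single key.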
