Skip to content

Commit 54328f1

Browse files
committed
Fix model device handling in CustomTrainer
1 parent 0dce958 commit 54328f1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

chebai/trainer/CustomTrainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def predict_from_file(
3434
smiles_strings = [inp.strip() for inp in input.readlines()]
3535
loaded_model.eval()
3636
predictions = self._predict_smiles(loaded_model, smiles_strings)
37-
predictions_df = pd.DataFrame(predictions.detach().numpy())
37+
predictions_df = pd.DataFrame(predictions.detach().cpu().numpy())
3838
if classes_path is not None:
3939
with open(classes_path, "r") as f:
4040
predictions_df.columns = [cls.strip() for cls in f.readlines()]
@@ -44,7 +44,7 @@ def predict_from_file(
4444
def _predict_smiles(self, model: LightningModule, smiles: List[str]):
4545
reader = ChemDataReader()
4646
parsed_smiles = [reader._read_data(s) for s in smiles]
47-
x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True)
47+
x = pad_sequence([torch.tensor(a, device=model.device) for a in parsed_smiles], batch_first=True)
4848
cls_tokens = (
4949
torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1)
5050
* CLS_TOKEN

0 commit comments

Comments
 (0)