File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -34,7 +34,7 @@ def predict_from_file(
3434 smiles_strings = [inp .strip () for inp in input .readlines ()]
3535 loaded_model .eval ()
3636 predictions = self ._predict_smiles (loaded_model , smiles_strings )
37- predictions_df = pd .DataFrame (predictions .detach ().numpy ())
37+ predictions_df = pd .DataFrame (predictions .detach ().cpu (). numpy ())
3838 if classes_path is not None :
3939 with open (classes_path , "r" ) as f :
4040 predictions_df .columns = [cls .strip () for cls in f .readlines ()]
@@ -44,7 +44,7 @@ def predict_from_file(
4444 def _predict_smiles (self , model : LightningModule , smiles : List [str ]):
4545 reader = ChemDataReader ()
4646 parsed_smiles = [reader ._read_data (s ) for s in smiles ]
47- x = pad_sequence ([torch .tensor (a ) for a in parsed_smiles ], batch_first = True )
47+ x = pad_sequence ([torch .tensor (a , device = model . device ) for a in parsed_smiles ], batch_first = True )
4848 cls_tokens = (
4949 torch .ones (x .shape [0 ], dtype = torch .int , device = model .device ).unsqueeze (- 1 )
5050 * CLS_TOKEN
You can’t perform that action at this time.
0 commit comments