@@ -20,28 +20,33 @@ def __init__(self):
def predict_from_file(
    self,
    checkpoint_path: _PATH,
    smiles_file_path: _PATH,
    save_to: _PATH = "predictions.csv",
    classes_path: Optional[_PATH] = None,
    batch_size: Optional[int] = None,
) -> None:
    """
    Loads a model from a checkpoint and makes predictions on input data from a file.

    The file is read as UTF-8 text; every line is stripped of surrounding
    whitespace and passed on as one SMILES string. All remaining work is
    delegated to ``self.predict_smiles``.

    Args:
        checkpoint_path: Path to the model checkpoint.
        smiles_file_path: Path to the input file containing SMILES strings,
            one per line.
        save_to: Path to save the predictions CSV file.
        classes_path: Optional path to a file containing class names. If no
            class names are provided, code will try to get the class path
            from the datamodule, else the columns will be numbered.
        batch_size: Optional batch size for the DataLoader. If not provided,
            the default from the datamodule will be used.
    """
    # Iterate the handle directly instead of materialising `readlines()`, and
    # avoid naming it `input`, which would shadow the builtin.
    # Blank lines are deliberately kept (as empty strings) so that the number
    # and order of entries still mirrors the lines of the input file.
    with open(smiles_file_path, "r", encoding="utf-8") as smiles_file:
        smiles_strings = [line.strip() for line in smiles_file]

    self.predict_smiles(
        checkpoint_path,
        smiles=smiles_strings,
        classes_path=classes_path,
        save_to=save_to,
        batch_size=batch_size,
    )
4651
4752 @torch .inference_mode ()
@@ -51,16 +56,24 @@ def predict_smiles(
5156 smiles : List [str ],
5257 classes_path : Optional [_PATH ] = None ,
5358 save_to : Optional [_PATH ] = None ,
59+ batch_size : Optional [int ] = None ,
5460 ** kwargs ,
5561 ) -> torch .Tensor | None :
5662 """
5763 Predicts the output for a list of SMILES strings using the model.
5864
5965 Args:
60- model: The model to use for predictions .
66+ checkpoint_path: Path to the model checkpoint .
6167 smiles: A list of SMILES strings.
62-
63- Returns:
68+ classes_path: Optional path to a file containing class names. If no class
69+ names are provided, code will try to get the class path from the datamodule,
70+ else the columns will be numbered.
71+ save_to: Optional path to save the predictions CSV file. If not provided,
72+ predictions will be returned as a tensor.
73+ batch_size: Optional batch size for the DataLoader. If not provided, the default
74+ from the datamodule will be used.
75+
76+ Returns: (if save_to is None)
6477 A tensor containing the predictions.
6578 """
6679 ckpt_file = torch .load (
@@ -71,10 +84,13 @@ def predict_smiles(
7184 dm_hparams .pop ("splits_file_path" )
7285 dm : XYBaseDataModule = instantiate_module (XYBaseDataModule , dm_hparams )
7386 print (f"Loaded datamodule class: { dm .__class__ .__name__ } " )
87+ if batch_size is not None :
88+ dm .batch_size = batch_size
7489
7590 model_hparams = ckpt_file ["hyper_parameters" ]
7691 model : ChebaiBaseNet = instantiate_module (ChebaiBaseNet , model_hparams )
7792 model .eval ()
93+ # TODO: Enable torch.compile when supported
7894 # model = torch.compile(model)
7995 print (f"Loaded model class: { model .__class__ .__name__ } " )
8096
0 commit comments