|
 class Predictor:
-    def __init__(self):
+    def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
+        """Initializes the Predictor with a model loaded from the checkpoint.
+
+        Args:
+            checkpoint_path: Path to the model checkpoint.
+            batch_size: Optional batch size for the DataLoader. If not provided,
+                the default from the datamodule will be used.
+        """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        super().__init__()
+        ckpt_file = torch.load(
+            checkpoint_path, map_location=self.device, weights_only=False
+        )
+
+        self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
+        self._dm_hparams.pop("splits_file_path")
+        self._dm: XYBaseDataModule = instantiate_module(
+            XYBaseDataModule, self._dm_hparams
+        )
+        print(f"Loaded datamodule class: {self._dm.__class__.__name__}")
+        if batch_size is not None and int(batch_size) > 0:
+            self._dm.batch_size = int(batch_size)
+
+        self._model_hparams = ckpt_file["hyper_parameters"]
+        self._model: ChebaiBaseNet = instantiate_module(
+            ChebaiBaseNet, self._model_hparams
+        )
+        self._model.eval()
+        # TODO: Enable torch.compile when supported
+        # model = torch.compile(model)
+        print(f"Loaded model class: {self._model.__class__.__name__}")

     def predict_from_file(
         self,
-        checkpoint_path: _PATH,
         smiles_file_path: _PATH,
         save_to: _PATH = "predictions.csv",
         classes_path: Optional[_PATH] = None,
-        batch_size: Optional[int] = None,
     ) -> None:
         """
         Makes predictions on input data from a file using the loaded model.

         Args:
-            checkpoint_path: Path to the model checkpoint.
             smiles_file_path: Path to the input file containing SMILES strings.
             save_to: Path to save the predictions CSV file.
             classes_path: Optional path to a file containing class names. If no class
                 names are provided, the code will try to get the classes file from
                 the datamodule; otherwise the columns will be numbered.
-            batch_size: Optional batch size for the DataLoader. If not provided,
-                the default from the datamodule will be used.
         """
         with open(smiles_file_path, "r") as input:
             smiles_strings = [inp.strip() for inp in input.readlines()]

-        self.predict_smiles(
-            checkpoint_path,
+        preds: torch.Tensor = self.predict_smiles(
             smiles=smiles_strings,
-            classes_path=classes_path,
-            save_to=save_to,
-            batch_size=batch_size,
         )

+        # predict_smiles returns a single concatenated tensor of predictions
+        predictions_df = pd.DataFrame(preds.detach().cpu().numpy())
+
+        def _add_class_columns(class_file_path: _PATH):
+            with open(class_file_path, "r") as f:
+                predictions_df.columns = [cls.strip() for cls in f.readlines()]
+
+        if classes_path is not None:
+            _add_class_columns(classes_path)
+        elif os.path.exists(self._dm.classes_txt_file_path):
+            _add_class_columns(self._dm.classes_txt_file_path)
+
+        predictions_df.index = smiles_strings
+        predictions_df.to_csv(save_to)
+
     @torch.inference_mode()
     def predict_smiles(
         self,
-        checkpoint_path: _PATH,
         smiles: List[str],
-        classes_path: Optional[_PATH] = None,
-        save_to: Optional[_PATH] = None,
-        batch_size: Optional[int] = None,
-        **kwargs,
-    ) -> torch.Tensor | None:
+    ) -> torch.Tensor:
         """
         Predicts the output for a list of SMILES strings using the model.

         Args:
-            checkpoint_path: Path to the model checkpoint.
             smiles: A list of SMILES strings.
-            classes_path: Optional path to a file containing class names. If no class
-                names are provided, code will try to get the class path from the datamodule,
-                else the columns will be numbered.
-            save_to: Optional path to save the predictions CSV file. If not provided,
-                predictions will be returned as a tensor.
-            batch_size: Optional batch size for the DataLoader. If not provided, the default
-                from the datamodule will be used.
-
-        Returns: (if save_to is None)
+
+        Returns:
             A tensor containing the predictions.
         """
-        ckpt_file = torch.load(
-            checkpoint_path, map_location=self.device, weights_only=False
-        )
-
-        dm_hparams = ckpt_file["datamodule_hyper_parameters"]
-        dm_hparams.pop("splits_file_path")
-        dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams)
-        print(f"Loaded datamodule class: {dm.__class__.__name__}")
-        if batch_size is not None:
-            dm.batch_size = batch_size
-
-        model_hparams = ckpt_file["hyper_parameters"]
-        model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams)
-        model.eval()
-        # TODO: Enable torch.compile when supported
-        # model = torch.compile(model)
-        print(f"Loaded model class: {model.__class__.__name__}")
-
         # For certain data prediction pipelines, we may need model hyperparameters
-        pred_dl: DataLoader = dm.predict_dataloader(
-            smiles_list=smiles, model_hparams=model_hparams
+        pred_dl: DataLoader = self._dm.predict_dataloader(
+            smiles_list=smiles, model_hparams=self._model_hparams
         )

         preds = []
         for batch_idx, batch in enumerate(pred_dl):
             # For certain model prediction pipelines, we may need data module hyperparameters
-            preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams))
+            preds.append(
+                self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
+            )

-        if not save_to:
-            # If no save path is provided, return the predictions tensor
-            return torch.cat(preds)
+        return torch.cat(preds)

-        predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())
-
-        def _add_class_columns(class_file_path: _PATH):
-            with open(class_file_path, "r") as f:
-                predictions_df.columns = [cls.strip() for cls in f.readlines()]

-        if classes_path is not None:
-            _add_class_columns(classes_path)
-        elif os.path.exists(dm.classes_txt_file_path):
-            _add_class_columns(dm.classes_txt_file_path)
+class MainPredictor:
+    @staticmethod
+    def predict_from_file(
+        checkpoint_path: _PATH,
+        smiles_file_path: _PATH,
+        save_to: _PATH = "predictions.csv",
+        classes_path: Optional[_PATH] = None,
+        batch_size: Optional[int] = None,
+    ) -> None:
+        predictor = Predictor(checkpoint_path, batch_size)
+        predictor.predict_from_file(
+            smiles_file_path,
+            save_to,
+            classes_path,
+        )

-        predictions_df.index = smiles
-        predictions_df.to_csv(save_to)
+    @staticmethod
+    def predict_smiles(
+        checkpoint_path: _PATH,
+        smiles: List[str],
+        batch_size: Optional[int] = None,
+    ) -> torch.Tensor:
+        predictor = Predictor(checkpoint_path, batch_size)
+        return predictor.predict_smiles(smiles)


 if __name__ == "__main__":
     # python chebai/result/prediction.py predict_from_file --help
-    CLI(Predictor, as_positional=False)
+    # python chebai/result/prediction.py predict_smiles --help
+    CLI(MainPredictor, as_positional=False)
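
For reference, a minimal usage sketch of the refactored API after this change. It assumes the module is importable as chebai.result.prediction (as the CLI comments above suggest); the checkpoint path, SMILES strings, and file names are placeholders, not part of the change.

# Sketch only: paths, SMILES, and batch size below are placeholders.
from chebai.result.prediction import MainPredictor, Predictor

# One-shot helper: loads the checkpoint, rebuilds the datamodule, and predicts.
preds = MainPredictor.predict_smiles(
    checkpoint_path="path/to/model.ckpt",
    smiles=["CCO", "c1ccccc1"],
    batch_size=32,
)
print(preds.shape)  # expected: one row of predictions per input SMILES

# Reusing a single Predictor instance avoids reloading the checkpoint between calls.
predictor = Predictor("path/to/model.ckpt", batch_size=32)
predictor.predict_from_file("smiles.txt", save_to="predictions.csv")

The split appears to move all checkpoint and datamodule loading into Predictor.__init__, so it happens once per instance, while MainPredictor keeps the two CLI entry points (see the --help comments in __main__) as thin static wrappers.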