|
13 | 13 |
|
14 | 14 |
|
15 | 15 | class Predictor: |
16 | | - def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): |
| 16 | + def __init__( |
| 17 | + self, |
| 18 | + checkpoint_path: _PATH, |
| 19 | + batch_size: Optional[int] = None, |
| 20 | + compile_model: bool = True, |
| 21 | + ): |
17 | 22 | """Initializes the Predictor with a model loaded from the checkpoint. |
18 | 23 |
|
19 | 24 | Args: |
20 | 25 | checkpoint_path: Path to the model checkpoint. |
21 | 26 | batch_size: Optional batch size for the DataLoader. If not provided, |
22 | 27 | the default from the datamodule will be used. |
| 28 | + compile_model: Whether to compile the model using torch.compile. Default is True. |
23 | 29 | """ |
24 | 30 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
25 | 31 | ckpt_file = torch.load( |
26 | 32 | checkpoint_path, map_location=self.device, weights_only=False |
27 | 33 | ) |
| 34 | + print("-" * 50) |
| 35 | + print(f"For Loaded checkpoint from: {checkpoint_path}") |
| 36 | + print("Below are the modules loaded from the checkpoint:") |
28 | 37 |
|
29 | 38 | self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] |
30 | 39 | # self._dm_hparams.pop("splits_file_path") |
31 | 40 | self._dm: XYBaseDataModule = instantiate_module( |
32 | 41 | XYBaseDataModule, self._dm_hparams |
33 | 42 | ) |
34 | | - print(f"Loaded datamodule class: {self._dm.__class__.__name__}") |
35 | 43 | if batch_size is not None and int(batch_size) > 0: |
36 | 44 | self._dm.batch_size = int(batch_size) |
| 45 | + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") |
37 | 46 |
|
38 | 47 | self._model_hparams = ckpt_file["hyper_parameters"] |
39 | 48 | self._model: ChebaiBaseNet = instantiate_module( |
40 | 49 | ChebaiBaseNet, self._model_hparams |
41 | 50 | ) |
| 51 | + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") |
| 52 | + |
| 53 | + if compile_model: |
| 54 | + self._model = torch.compile(self._model) |
42 | 55 | self._model.eval() |
43 | | - # TODO: Enable torch.compile when supported |
44 | | - # model = torch.compile(model) |
45 | | - print(f"Loaded model class: {self._model.__class__.__name__}") |
| 56 | + print("-" * 50) |
46 | 57 |
|
47 | 58 | def predict_from_file( |
48 | 59 | self, |
|
0 commit comments