Skip to content

Commit ba5884a

Browse files
committed
compile model
1 parent d906ad4 commit ba5884a

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

chebai/result/prediction.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,47 @@
1313

1414

1515
class Predictor:
16-
def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
16+
def __init__(
17+
self,
18+
checkpoint_path: _PATH,
19+
batch_size: Optional[int] = None,
20+
compile_model: bool = True,
21+
):
1722
"""Initializes the Predictor with a model loaded from the checkpoint.
1823
1924
Args:
2025
checkpoint_path: Path to the model checkpoint.
2126
batch_size: Optional batch size for the DataLoader. If not provided,
2227
the default from the datamodule will be used.
28+
compile_model: Whether to compile the model using torch.compile. Default is True.
2329
"""
2430
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2531
ckpt_file = torch.load(
2632
checkpoint_path, map_location=self.device, weights_only=False
2733
)
34+
print("-" * 50)
35+
print(f"For Loaded checkpoint from: {checkpoint_path}")
36+
print("Below are the modules loaded from the checkpoint:")
2837

2938
self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
3039
# self._dm_hparams.pop("splits_file_path")
3140
self._dm: XYBaseDataModule = instantiate_module(
3241
XYBaseDataModule, self._dm_hparams
3342
)
34-
print(f"Loaded datamodule class: {self._dm.__class__.__name__}")
3543
if batch_size is not None and int(batch_size) > 0:
3644
self._dm.batch_size = int(batch_size)
45+
print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}")
3746

3847
self._model_hparams = ckpt_file["hyper_parameters"]
3948
self._model: ChebaiBaseNet = instantiate_module(
4049
ChebaiBaseNet, self._model_hparams
4150
)
51+
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
52+
53+
if compile_model:
54+
self._model = torch.compile(self._model)
4255
self._model.eval()
43-
# TODO: Enable torch.compile when supported
44-
# model = torch.compile(model)
45-
print(f"Loaded model class: {self._model.__class__.__name__}")
56+
print("-" * 50)
4657

4758
def predict_from_file(
4859
self,

0 commit comments

Comments
 (0)