Skip to content

Commit c6e8b61

Browse files
committed
modify pred logic to store model and dm as instance var
1 parent 31b12db commit c6e8b61

File tree

1 file changed

+78
-62
lines changed

1 file changed

+78
-62
lines changed

chebai/result/prediction.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,116 +13,132 @@
1313

1414

1515
class Predictor:
16-
def __init__(self):
16+
def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None):
17+
"""Initializes the Predictor with a model loaded from the checkpoint.
18+
19+
Args:
20+
checkpoint_path: Path to the model checkpoint.
21+
batch_size: Optional batch size for the DataLoader. If not provided,
22+
the default from the datamodule will be used.
23+
"""
1724
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18-
super().__init__()
25+
ckpt_file = torch.load(
26+
checkpoint_path, map_location=self.device, weights_only=False
27+
)
28+
29+
self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
30+
self._dm_hparams.pop("splits_file_path")
31+
self._dm: XYBaseDataModule = instantiate_module(
32+
XYBaseDataModule, self._dm_hparams
33+
)
34+
print(f"Loaded datamodule class: {self._dm.__class__.__name__}")
35+
if batch_size is not None and int(batch_size) > 0:
36+
self._dm.batch_size = int(batch_size)
37+
38+
self._model_hparams = ckpt_file["hyper_parameters"]
39+
self._model: ChebaiBaseNet = instantiate_module(
40+
ChebaiBaseNet, self._model_hparams
41+
)
42+
self._model.eval()
43+
# TODO: Enable torch.compile when supported
44+
# model = torch.compile(model)
45+
print(f"Loaded model class: {self._model.__class__.__name__}")
1946

2047
def predict_from_file(
2148
self,
22-
checkpoint_path: _PATH,
2349
smiles_file_path: _PATH,
2450
save_to: _PATH = "predictions.csv",
2551
classes_path: Optional[_PATH] = None,
26-
batch_size: Optional[int] = None,
2752
) -> None:
2853
"""
2954
Loads a model from a checkpoint and makes predictions on input data from a file.
3055
3156
Args:
32-
checkpoint_path: Path to the model checkpoint.
3357
smiles_file_path: Path to the input file containing SMILES strings.
3458
save_to: Path to save the predictions CSV file.
3559
classes_path: Optional path to a file containing class names:
3660
if no class names are provided, code will try to get the class path
3761
from the datamodule, else the columns will be numbered.
38-
batch_size: Optional batch size for the DataLoader. If not provided,
39-
the default from the datamodule will be used.
4062
"""
4163
with open(smiles_file_path, "r") as input:
4264
smiles_strings = [inp.strip() for inp in input.readlines()]
4365

44-
self.predict_smiles(
45-
checkpoint_path,
66+
preds: torch.Tensor = self.predict_smiles(
4667
smiles=smiles_strings,
4768
classes_path=classes_path,
4869
save_to=save_to,
49-
batch_size=batch_size,
5070
)
5171

72+
predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())
73+
74+
def _add_class_columns(class_file_path: _PATH):
75+
with open(class_file_path, "r") as f:
76+
predictions_df.columns = [cls.strip() for cls in f.readlines()]
77+
78+
if classes_path is not None:
79+
_add_class_columns(classes_path)
80+
elif os.path.exists(self._dm.classes_txt_file_path):
81+
_add_class_columns(self._dm.classes_txt_file_path)
82+
83+
predictions_df.index = smiles_strings
84+
predictions_df.to_csv(save_to)
85+
5286
@torch.inference_mode()
5387
def predict_smiles(
5488
self,
55-
checkpoint_path: _PATH,
5689
smiles: List[str],
57-
classes_path: Optional[_PATH] = None,
58-
save_to: Optional[_PATH] = None,
59-
batch_size: Optional[int] = None,
60-
**kwargs,
61-
) -> torch.Tensor | None:
90+
) -> torch.Tensor:
6291
"""
6392
Predicts the output for a list of SMILES strings using the model.
6493
6594
Args:
66-
checkpoint_path: Path to the model checkpoint.
6795
smiles: A list of SMILES strings.
68-
classes_path: Optional path to a file containing class names. If no class
69-
names are provided, code will try to get the class path from the datamodule,
70-
else the columns will be numbered.
71-
save_to: Optional path to save the predictions CSV file. If not provided,
72-
predictions will be returned as a tensor.
73-
batch_size: Optional batch size for the DataLoader. If not provided, the default
74-
from the datamodule will be used.
75-
76-
Returns: (if save_to is None)
96+
97+
Returns:
7798
A tensor containing the predictions.
7899
"""
79-
ckpt_file = torch.load(
80-
checkpoint_path, map_location=self.device, weights_only=False
81-
)
82-
83-
dm_hparams = ckpt_file["datamodule_hyper_parameters"]
84-
dm_hparams.pop("splits_file_path")
85-
dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams)
86-
print(f"Loaded datamodule class: {dm.__class__.__name__}")
87-
if batch_size is not None:
88-
dm.batch_size = batch_size
89-
90-
model_hparams = ckpt_file["hyper_parameters"]
91-
model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams)
92-
model.eval()
93-
# TODO: Enable torch.compile when supported
94-
# model = torch.compile(model)
95-
print(f"Loaded model class: {model.__class__.__name__}")
96-
97100
# For certain data prediction pipelines, we may need model hyperparameters
98-
pred_dl: DataLoader = dm.predict_dataloader(
99-
smiles_list=smiles, model_hparams=model_hparams
101+
pred_dl: DataLoader = self._dm.predict_dataloader(
102+
smiles_list=smiles, model_hparams=self._model_hparams
100103
)
101104

102105
preds = []
103106
for batch_idx, batch in enumerate(pred_dl):
104107
# For certain model prediction pipelines, we may need data module hyperparameters
105-
preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams))
108+
preds.append(
109+
self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
110+
)
106111

107-
if not save_to:
108-
# If no save path is provided, return the predictions tensor
109-
return torch.cat(preds)
112+
return torch.cat(preds)
110113

111-
predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy())
112-
113-
def _add_class_columns(class_file_path: _PATH):
114-
with open(class_file_path, "r") as f:
115-
predictions_df.columns = [cls.strip() for cls in f.readlines()]
116114

117-
if classes_path is not None:
118-
_add_class_columns(classes_path)
119-
elif os.path.exists(dm.classes_txt_file_path):
120-
_add_class_columns(dm.classes_txt_file_path)
115+
class MainPredictor:
116+
@staticmethod
117+
def predict_from_file(
118+
checkpoint_path: _PATH,
119+
smiles_file_path: _PATH,
120+
save_to: _PATH = "predictions.csv",
121+
classes_path: Optional[_PATH] = None,
122+
batch_size: Optional[int] = None,
123+
) -> None:
124+
predictor = Predictor(checkpoint_path, batch_size)
125+
predictor.predict_from_file(
126+
smiles_file_path,
127+
save_to,
128+
classes_path,
129+
)
121130

122-
predictions_df.index = smiles
123-
predictions_df.to_csv(save_to)
131+
@staticmethod
132+
def predict_smiles(
133+
checkpoint_path: _PATH,
134+
smiles: List[str],
135+
batch_size: Optional[int] = None,
136+
) -> torch.Tensor:
137+
predictor = Predictor(checkpoint_path, batch_size)
138+
return predictor.predict_smiles(smiles)
124139

125140

126141
if __name__ == "__main__":
127142
# python chebai/result/prediction.py predict_from_file --help
128-
CLI(Predictor, as_positional=False)
143+
# python chebai/result/prediction.py predict_smiles --help
144+
CLI(MainPredictor, as_positional=False)

0 commit comments

Comments
 (0)