11from abc import ABC , abstractmethod
22from collections import deque
33from pathlib import Path
4- from typing import Any , Deque , Dict , Optional
4+ from typing import Any , Deque , Dict
55
66import pandas as pd
77import torch
8- from lightning import LightningModule
98
109from chebai .result .classification import print_metrics
1110
@@ -29,7 +28,7 @@ def __init__(
2928 self ,
3029 model_configs : Dict [str , Dict [str , Any ]],
3130 data_processed_dir_main : str ,
32- operation : str = EVAL_OP ,
31+ operation_mode : str = EVAL_OP ,
3332 ** kwargs : Any ,
3433 ) -> None :
3534 """
@@ -42,13 +41,13 @@ def __init__(
4241 """
4342 if bool (kwargs .get ("_perform_validation_checks" , True )):
4443 self ._perform_validation_checks (
45- model_configs , operation = operation , ** kwargs
44+ model_configs , operation = operation_mode , ** kwargs
4645 )
4746
4847 self ._model_configs : Dict [str , Dict [str , Any ]] = model_configs
4948 self ._data_processed_dir_main : str = data_processed_dir_main
50- self ._operation : str = operation
51- print (f"Ensemble operation: { self ._operation } " )
49+ self ._operation_mode : str = operation_mode
50+ print (f"Ensemble operation: { self ._operation_mode } " )
5251
5352 # These instance variable will be set in method `_process_input_to_ensemble`
5453 self ._total_data_size : int | None = None
@@ -126,7 +125,7 @@ def _perform_validation_checks(
126125 labels_set .add (model_labels_path )
127126
128127 def _process_input_to_ensemble (self , ** kwargs : Any ) -> list [str ] | Path :
129- if self ._operation == PRED_OP :
128+ if self ._operation_mode == PRED_OP :
130129 p = Path (kwargs ["smiles_list_file_path" ])
131130 smiles_list : list [str ] = []
132131 with open (p , "r" ) as f :
@@ -138,7 +137,7 @@ def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
138137 smiles_list .append (smiles )
139138 self ._total_data_size = len (smiles_list )
140139 return smiles_list
141- elif self ._operation == EVAL_OP :
140+ elif self ._operation_mode == EVAL_OP :
142141 processed_dir_path = Path (self ._data_processed_dir_main )
143142 data_pkl_path = processed_dir_path / "data.pkl"
144143 if not data_pkl_path .exists ():
@@ -183,7 +182,7 @@ def run_ensemble(self) -> None:
183182 )
184183
185184 print (
186- f"Running { self .__class__ .__name__ } ensemble for { self ._operation } operation..."
185+ f"Running { self .__class__ .__name__ } ensemble for { self ._operation_mode } operation..."
187186 )
188187 while self ._model_queue :
189188 model_name = self ._model_queue .popleft ()
@@ -204,7 +203,7 @@ def run_ensemble(self) -> None:
204203 true_scores = true_scores , false_scores = false_scores
205204 )
206205
207- if self ._operation == EVAL_OP :
206+ if self ._operation_mode == EVAL_OP :
208207 assert (
209208 self ._collated_labels is not None
210209 ), "Collated labels must be set for evaluation operation."
@@ -214,6 +213,31 @@ def run_ensemble(self) -> None:
214213 self ._device ,
215214 classes = list (self ._dm_labels .keys ()),
216215 )
216+ else :
217+ # Get SMILES and label names
218+ smiles_list = self ._ensemble_input
219+ label_names = list (self ._dm_labels .keys ())
220+ # Efficient conversion from tensor to NumPy
221+ preds_np = final_preds .detach ().cpu ().numpy ()
222+
223+ assert (
224+ len (smiles_list ) == preds_np .shape [0 ]
225+ ), "Length of SMILES list does not match number of predictions."
226+ assert (
227+ len (label_names ) == preds_np .shape [1 ]
228+ ), "Number of label names does not match number of predictions."
229+
230+ # Build DataFrame
231+ df = pd .DataFrame (preds_np , columns = label_names )
232+ df .insert (0 , "SMILES" , smiles_list )
233+
234+ # Save to CSV
235+ output_path = (
236+ Path (self ._data_processed_dir_main ) / "ensemble_predictions.csv"
237+ )
238+ df .to_csv (output_path , index = False )
239+
240+ print (f"Predictions saved to { output_path } " )
217241
218242 @abstractmethod
219243 def _controller (
0 commit comments