11from abc import ABC , abstractmethod
22from collections import deque
33from pathlib import Path
4- from typing import Any , Deque , Dict , Literal , Optional
4+ from typing import Any , Deque , Dict , Optional
55
66import pandas as pd
77import torch
@@ -38,7 +38,6 @@ def __init__(
3838 Args:
3939 model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
4040 data_processed_dir_main (str): Path to the processed data directory.
41- reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'.
4241 **kwargs (Any): Additional arguments, such as '_perform_validation_checks' and 'smiles_list_file_path'.
4342 """
4443 if bool (kwargs .get ("_perform_validation_checks" , True )):
@@ -51,16 +50,15 @@ def __init__(
5150 self ._operation : str = operation
5251 print (f"Ensemble operation: { self ._operation } " )
5352
54- self . _input_dim : Optional [ int ] = kwargs . get ( "input_dim" , None )
55- self ._total_data_size : int = None
53+ # This instance variable is set in the method `_process_input_to_ensemble`
54+ self ._total_data_size : int | None = None
5655 self ._ensemble_input : list [str ] | Path = self ._process_input_to_ensemble (
5756 ** kwargs
5857 )
5958 print (f"Total data size (data.pkl) is { self ._total_data_size } " )
6059
6160 self ._device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
6261
63- self ._models : Dict [str , LightningModule ] = {}
6462 self ._dm_labels : Dict [str , int ] = self ._load_data_module_labels ()
6563 self ._num_of_labels : int = len (self ._dm_labels )
6664 print (f"Number of labes for this data is { self ._num_of_labels } " )
@@ -69,6 +67,7 @@ def __init__(
6967 1 , self ._num_of_labels , device = self ._device
7068 )
7169 self ._model_queue : Deque [str ] = deque ()
70+ self ._collated_labels : torch .Tensor | None = None
7271
7372 @classmethod
7473 def _perform_validation_checks (
@@ -126,10 +125,10 @@ def _perform_validation_checks(
126125 class_set .add (model_class_path )
127126 labels_set .add (model_labels_path )
128127
129- def _process_input_to_ensemble (self , ** kwargs : any ) -> list [str ] | Path :
128+ def _process_input_to_ensemble (self , ** kwargs : Any ) -> list [str ] | Path :
130129 if self ._operation == PRED_OP :
131130 p = Path (kwargs ["smiles_list_file_path" ])
132- smiles_list = []
131+ smiles_list : list [ str ] = []
133132 with open (p , "r" ) as f :
134133 for line in f :
135134 # Skip empty or whitespace-only lines
@@ -140,11 +139,14 @@ def _process_input_to_ensemble(self, **kwargs: any) -> list[str] | Path:
140139 self ._total_data_size = len (smiles_list )
141140 return smiles_list
142141 elif self ._operation == EVAL_OP :
143- data_pkl_path = Path (self ._data_processed_dir_main ) / "data.pkl"
142+ processed_dir_path = Path (self ._data_processed_dir_main )
143+ data_pkl_path = processed_dir_path / "data.pkl"
144144 if not data_pkl_path .exists ():
145- raise FileNotFoundError ()
145+ raise FileNotFoundError (
146+ f"data.pkl does not exist in the { processed_dir_path } directory"
147+ )
146148 self ._total_data_size = len (pd .read_pickle (data_pkl_path ))
147- return p
149+ return processed_dir_path
148150 else :
149151 raise ValueError ("Invalid operation" )
150152
@@ -180,6 +182,9 @@ def run_ensemble(self) -> None:
180182 self ._total_data_size , self ._num_of_labels , device = self ._device
181183 )
182184
185+ print (
186+ f"Running { self .__class__ .__name__ } ensemble for { self ._operation } operation..."
187+ )
183188 while self ._model_queue :
184189 model_name = self ._model_queue .popleft ()
185190 print (f"Processing model: { model_name } " )
@@ -195,16 +200,17 @@ def run_ensemble(self) -> None:
195200 false_scores = false_scores ,
196201 )
197202
198- print (f"Consolidating predictions for { self .__class__ .__name__ } " )
199203 final_preds = self ._consolidate_on_finish (
200204 true_scores = true_scores , false_scores = false_scores
201205 )
202- print_metrics (
203- final_preds ,
204- self ._collated_data .y ,
205- self ._device ,
206- classes = list (self ._dm_labels .keys ()),
207- )
206+
207+ if self ._operation == EVAL_OP :
208+ print_metrics (
209+ final_preds ,
210+ self ._collated_labels ,
211+ self ._device ,
212+ classes = list (self ._dm_labels .keys ()),
213+ )
208214
209215 @abstractmethod
210216 def _controller (
0 commit comments