1- import importlib
2- import json
3- import os
41from abc import ABC , abstractmethod
52from collections import deque
6- from typing import Any , Deque , Dict , Optional , Tuple
3+ from pathlib import Path
4+ from typing import Any , Deque , Dict , Optional
75
6+ import pandas as pd
87import torch
98from lightning import LightningModule
109
11- from chebai .models import ChebaiBaseNet
12- from chebai .preprocessing .structures import XYData
1310from chebai .result .classification import print_metrics
1411
15- from ._constants import (
16- MODEL_CKPT_PATH ,
17- MODEL_CLS_PATH ,
18- MODEL_LBL_PATH ,
19- READER_CLS_PATH ,
20- WRAPPER_CLS_PATH ,
21- )
22- from ._utils import _load_class
23- from ._wrappers import BaseWrapper
12+ from ._constants import MODEL_CLS_PATH , MODEL_LBL_PATH , WRAPPER_CLS_PATH
2413
2514
2615class EnsembleBase (ABC ):
@@ -33,38 +22,45 @@ class EnsembleBase(ABC):
def __init__(
    self,
    model_configs: Dict[str, Dict[str, Any]],
    data_file_path: str,
    classes_file_path: str,
    **kwargs: Any,
) -> None:
    """
    Initializes the ensemble: validates the configs, resolves the input data and loads the labels.

    Args:
        model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
        data_file_path (str): Path to the ensemble input. It is handed to
            `_process_input_to_ensemble`, which accepts either a file or a
            directory and also records the total data size.
        classes_file_path (str): Path to the file listing the class labels. It is
            handed to `_load_data_module_labels`.
        **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'.
    """
    # Validation is on by default; callers opt out with `_validate_configs=False`.
    if bool(kwargs.get("_validate_configs", True)):
        self._validate_model_configs(model_configs)

    self._model_configs: Dict[str, Dict[str, Any]] = model_configs
    self._data_file_path: str = data_file_path
    self._classes_file_path: str = classes_file_path
    self._input_dim: Optional[int] = kwargs.get("input_dim", None)

    # Stays None only until `_process_input_to_ensemble` fills it in as a side effect.
    self._total_data_size: Optional[int] = None
    self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
        data_file_path
    )
    print(f"Total data size (data.pkl) is {self._total_data_size}")

    self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self._models: Dict[str, LightningModule] = {}
    self._dm_labels: Dict[str, int] = self._load_data_module_labels(
        classes_file_path
    )
    self._num_of_labels: int = len(self._dm_labels)
    print(f"Number of labels for this data is {self._num_of_labels}")

    # One counter per label, kept on the same device as the score tensors.
    self._num_models_per_label: torch.Tensor = torch.zeros(
        1, self._num_of_labels, device=self._device
    )
    self._model_queue: Deque[str] = deque()
6864
6965 @classmethod
7066 def _validate_model_configs (cls , model_configs : Dict [str , Dict [str , Any ]]) -> None :
@@ -107,21 +103,43 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
107103 class_set .add (model_class_path )
108104 labels_set .add (model_labels_path )
109105
110- def _load_data_module_labels (self ) -> dict [str , int ]:
106+ def _process_input_to_ensemble (self , path : str ):
107+ p = Path (path )
108+ if p .is_file ():
109+ smiles_list = []
110+ with open (p , "r" ) as f :
111+ for line in f :
112+ # Skip empty or whitespace-only lines
113+ if line .strip ():
114+ # Split on whitespace and take the first item as the SMILES
115+ smiles = line .strip ().split ()[0 ]
116+ smiles_list .append (smiles )
117+ self ._total_data_size = len (smiles_list )
118+ return smiles_list
119+ elif p .is_dir ():
120+ data_pkl_path = p / "data.pkl"
121+ if not data_pkl_path .exists ():
122+ raise FileNotFoundError ()
123+ self ._total_data_size = len (pd .read_pickle (data_pkl_path ))
124+ return p
125+ else :
126+ raise "Invalid path"
127+
128+ @staticmethod
129+ def _load_data_module_labels (classes_file_path : str ) -> dict [str , int ]:
111130 """
112131 Loads class labels from the classes.txt file and sets internal label mapping.
113132
114133 Raises:
115134 FileNotFoundError: If the expected classes.txt file is not found.
116135 """
117- classes_txt_file = os .path .join (self ._data_processed_dir_main , "classes.txt" )
118- print (f"Loading { classes_txt_file } ...." )
119-
120- if not os .path .exists (classes_txt_file ):
121- raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
136+ classes_file_path = Path (classes_file_path )
137+ if not classes_file_path .exists ():
138+ raise FileNotFoundError (f"{ classes_file_path } does not exist" )
139+ print (f"Loading { classes_file_path } ...." )
122140
123141 dm_labels_dict = {}
124- with open (classes_txt_file , "r" ) as f :
142+ with open (classes_file_path , "r" ) as f :
125143 for line in f :
126144 label = line .strip ()
127145 if label not in dm_labels_dict :
@@ -132,6 +150,7 @@ def run_ensemble(self) -> None:
132150 """
133151 Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics.
134152 """
153+ assert self ._total_data_size is not None and self ._num_of_labels is not None
135154 true_scores = torch .zeros (
136155 self ._total_data_size , self ._num_of_labels , device = self ._device
137156 )
@@ -144,12 +163,12 @@ def run_ensemble(self) -> None:
144163 print (f"Processing model: { model_name } " )
145164
146165 print ("\t Passing model to controller to generate predictions..." )
147- pred_conf_dict , model_props = self ._controller (model_name )
166+ controller_output = self ._controller (model_name , self . _ensemble_input )
148167
149168 print ("\t Passing predictions to consolidator for aggregation..." )
150169 self ._consolidator (
151- pred_conf_dict ,
152- model_props ,
170+ pred_conf_dict = controller_output [ "pred_conf_dict" ] ,
171+ model_props = controller_output [ "model_props" ] ,
153172 true_scores = true_scores ,
154173 false_scores = false_scores ,
155174 )
@@ -168,8 +187,8 @@ def run_ensemble(self) -> None:
168187 @abstractmethod
169188 def _controller (
170189 self ,
171- model : LightningModule ,
172- model_props : Dict [str , torch . Tensor ] ,
190+ model_name : str ,
191+ model_input : list [str ] | Path ,
173192 ** kwargs : Any ,
174193 ) -> Dict [str , torch .Tensor ]:
175194 """
0 commit comments