Skip to content

Commit 4bd00ac

Browse files
committed
Make internal ensemble attributes private instance variables and add a `reader_dir_name` parameter
1 parent 69c5263 commit 4bd00ac

File tree

2 files changed

+45
-35
lines changed

2 files changed

+45
-35
lines changed

chebai/ensemble/base.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@ class EnsembleBase(ABC):
2222
2323
Attributes:
2424
data_processed_dir_main (str): Directory where the processed data is stored.
25-
models (Dict[str, LightningModule]): A dictionary of loaded models.
25+
_models (Dict[str, LightningModule]): A dictionary of loaded models.
2626
model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
27-
dm_labels (Dict[str, int]): Mapping of label names to integer indices.
27+
_dm_labels (Dict[str, int]): Mapping of label names to integer indices.
2828
"""
2929

3030
def __init__(
31-
self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs
31+
self,
32+
model_configs: Dict[str, Dict],
33+
data_processed_dir_main: str,
34+
reader_dir_name: str = "smiles_token",
35+
**kwargs,
3236
):
3337
"""
3438
Initializes the ensemble model and loads configuration, models, and labels.
@@ -41,22 +45,25 @@ def __init__(
4145
if bool(kwargs.get("_validate_configs", True)):
4246
self._validate_model_configs(model_configs)
4347

44-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48+
self.model_configs = model_configs
49+
self.data_processed_dir_main = data_processed_dir_main
50+
self.reader_dir_name = reader_dir_name
4551
self.input_dim = kwargs.get("input_dim", None)
46-
self.num_of_labels: Optional[int] = (
52+
53+
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54+
self._num_of_labels: Optional[int] = (
4755
None # will be set by `_load_data_module_labels` method
4856
)
49-
self.data_processed_dir_main = data_processed_dir_main
50-
self.models: Dict[str, LightningModule] = {}
51-
self.model_configs = model_configs
52-
self.dm_labels: Dict[str, int] = {}
57+
self._models: Dict[str, LightningModule] = {}
58+
self._dm_labels: Dict[str, int] = {}
5359

5460
self._load_data_module_labels()
5561
self._num_models_per_label: torch.Tensor = torch.zeros(
56-
1, self.num_of_labels, device=self.device
62+
1, self._num_of_labels, device=self._device
5763
)
5864
self._model_queue: Deque = deque()
5965
self._collated_data = None
66+
self._total_data_size: Optional[int] = None
6067

6168
@classmethod
6269
def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
@@ -121,14 +128,17 @@ def _load_data_module_labels(self):
121128
else:
122129
with open(classes_txt_file, "r") as f:
123130
for line in f:
124-
if line.strip() not in self.dm_labels:
125-
self.dm_labels[line.strip()] = len(self.dm_labels)
126-
self.num_of_labels = len(self.dm_labels)
131+
if line.strip() not in self._dm_labels:
132+
self._dm_labels[line.strip()] = len(self._dm_labels)
133+
self._num_of_labels = len(self._dm_labels)
127134

128135
def run_ensemble(self):
129-
batch_size = 10
130-
true_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)
131-
false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)
136+
true_scores = torch.zeros(
137+
self._total_data_size, self._num_of_labels, device=self._device
138+
)
139+
false_scores = torch.zeros(
140+
self._total_data_size, self._num_of_labels, device=self._device
141+
)
132142

133143
while self._model_queue:
134144
model_name = self._model_queue.popleft()
@@ -156,8 +166,8 @@ def run_ensemble(self):
156166
print_metrics(
157167
final_preds,
158168
self._collated_data.y,
159-
self.device,
160-
classes=list(self.dm_labels.keys()),
169+
self._device,
170+
classes=list(self._dm_labels.keys()),
161171
)
162172

163173
def _load_model_and_its_props(self, model_name):
@@ -209,33 +219,33 @@ def _generate_model_label_props(self, labels_path: str):
209219

210220
model_label_indices, tpv_label_values, fpv_label_values = [], [], []
211221
for label in labels_dict.keys():
212-
if label in self.dm_labels:
222+
if label in self._dm_labels:
213223
try:
214224
self._validate_model_labels_json_element(labels_dict[label])
215225
except Exception as e:
216226
raise Exception(f"Label '{label}' has an unexpected error: {e}")
217227

218-
model_label_indices.append(self.dm_labels[label])
228+
model_label_indices.append(self._dm_labels[label])
219229
tpv_label_values.append(labels_dict[label]["TPV"])
220230
fpv_label_values.append(labels_dict[label]["FPV"])
221231

222232
if not all([model_label_indices, tpv_label_values, fpv_label_values]):
223233
raise ValueError(f"Values are empty for labels of the model")
224234

225235
# Create masks to apply predictions only to known classes
226-
mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool)
227-
mask[torch.tensor(model_label_indices, dtype=torch.int, device=self.device)] = (
228-
True
229-
)
236+
mask = torch.zeros(self._num_of_labels, device=self._device, dtype=torch.bool)
237+
mask[
238+
torch.tensor(model_label_indices, dtype=torch.int, device=self._device)
239+
] = True
230240

231-
tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)
232-
fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)
241+
tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device)
242+
fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self._device)
233243

234244
tpv_tensor[mask] = torch.tensor(
235-
tpv_label_values, dtype=torch.float, device=self.device
245+
tpv_label_values, dtype=torch.float, device=self._device
236246
)
237247
fpv_tensor[mask] = torch.tensor(
238-
fpv_label_values, dtype=torch.float, device=self.device
248+
fpv_label_values, dtype=torch.float, device=self._device
239249
)
240250
self._num_models_per_label += mask
241251
return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor}

chebai/ensemble/controller.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@ def __init__(self, **kwargs):
1818

1919
self._collated_data = self._load_and_collate_data()
2020
self.input_dim = len(self._collated_data.x[0])
21-
self.total_data_size: int = len(self._collated_data)
21+
self._total_data_size: int = len(self._collated_data)
2222

2323
def _load_and_collate_data(self):
2424
data = torch.load(
25-
os.path.join(self.data_processed_dir_main, "smiles_token", "data.pt"),
25+
os.path.join(self.data_processed_dir_main, self.reader_dir_name, "data.pt"),
2626
weights_only=False,
27-
map_location=self.device,
27+
map_location=self._device,
2828
)
2929
collated_data = self._collator(data)
30-
collated_data.x = collated_data.to_x(self.device)
30+
collated_data.x = collated_data.to_x(self._device)
3131
if collated_data.y is not None:
32-
collated_data.y = collated_data.to_y(self.device)
32+
collated_data.y = collated_data.to_y(self._device)
3333
return collated_data
3434

3535
def _forward_pass(self, model: ChebaiBaseNet):
@@ -42,10 +42,10 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask):
4242
# Consider logits and confidence only for valid classes
4343
sigmoid_logits = torch.sigmoid(model_output["logits"])
4444
prediction = torch.full(
45-
(self.total_data_size, self.num_of_labels), -1, dtype=torch.bool
45+
(self._total_data_size, self._num_of_labels), -1, dtype=torch.bool
4646
)
4747
confidence = torch.full(
48-
(self.total_data_size, self.num_of_labels), -1, dtype=torch.float
48+
(self._total_data_size, self._num_of_labels), -1, dtype=torch.float
4949
)
5050
prediction[:, model_label_mask] = sigmoid_logits > 0.5
5151
confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5)

0 commit comments

Comments
 (0)