Skip to content

Commit 69c5263

Browse files
committed
ensemble minor changes
1 parent 825916e commit 69c5263

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

chebai/ensemble/consolidator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC
22

3-
from chebai.ensemble.base import EnsembleBase
3+
from .base import EnsembleBase
44

55

66
class WeightedMajorityVoting(EnsembleBase, ABC):

chebai/ensemble/controller.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
import os.path
22
from abc import ABC
3+
from collections import deque
4+
from typing import Deque
35

46
import torch
57

6-
from chebai.ensemble.base import EnsembleBase
78
from chebai.models import ChebaiBaseNet
89
from chebai.preprocessing.collate import RaggedCollator
910

11+
from .base import EnsembleBase
12+
1013

1114
class _Controller(EnsembleBase, ABC):
1215
def __init__(self, **kwargs):
1316
super().__init__(**kwargs)
1417
self._collator = RaggedCollator()
1518

1619
self._collated_data = self._load_and_collate_data()
20+
self.input_dim = len(self._collated_data.x[0])
1721
self.total_data_size: int = len(self._collated_data)
1822

1923
def _load_and_collate_data(self):
2024
data = torch.load(
21-
os.path.join(self.data_processed_dir_main, "data.pt"),
25+
os.path.join(self.data_processed_dir_main, "smiles_token", "data.pt"),
2226
weights_only=False,
2327
map_location=self.device,
2428
)
@@ -51,7 +55,7 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask):
5155
class NoActivationCondition(_Controller):
5256
def __init__(self, **kwargs):
5357
super().__init__(**kwargs)
54-
self._model_queue = list(self.model_configs.keys())
58+
self._model_queue: Deque = deque(list(self.model_configs.keys()))
5559

5660
def _controller(self, model, model_props, **kwargs):
5761
model_output = self._forward_pass(model)

0 commit comments

Comments
 (0)