|
1 | 1 | import os.path |
2 | 2 | from abc import ABC |
| 3 | +from collections import deque |
| 4 | +from typing import Deque |
3 | 5 |
|
4 | 6 | import torch |
5 | 7 |
|
6 | | -from chebai.ensemble.base import EnsembleBase |
7 | 8 | from chebai.models import ChebaiBaseNet |
8 | 9 | from chebai.preprocessing.collate import RaggedCollator |
9 | 10 |
|
| 11 | +from .base import EnsembleBase |
| 12 | + |
10 | 13 |
|
11 | 14 | class _Controller(EnsembleBase, ABC): |
12 | 15 | def __init__(self, **kwargs): |
13 | 16 | super().__init__(**kwargs) |
14 | 17 | self._collator = RaggedCollator() |
15 | 18 |
|
16 | 19 | self._collated_data = self._load_and_collate_data() |
| 20 | + self.input_dim = len(self._collated_data.x[0]) |
17 | 21 | self.total_data_size: int = len(self._collated_data) |
18 | 22 |
|
19 | 23 | def _load_and_collate_data(self): |
20 | 24 | data = torch.load( |
21 | | - os.path.join(self.data_processed_dir_main, "data.pt"), |
| 25 | + os.path.join(self.data_processed_dir_main, "smiles_token", "data.pt"), |
22 | 26 | weights_only=False, |
23 | 27 | map_location=self.device, |
24 | 28 | ) |
@@ -51,7 +55,7 @@ def _get_pred_conf_from_model_output(self, model_output, model_label_mask): |
51 | 55 | class NoActivationCondition(_Controller): |
52 | 56 | def __init__(self, **kwargs): |
53 | 57 | super().__init__(**kwargs) |
54 | | - self._model_queue = list(self.model_configs.keys()) |
| 58 | + self._model_queue: Deque = deque(list(self.model_configs.keys())) |
55 | 59 |
|
56 | 60 | def _controller(self, model, model_props, **kwargs): |
57 | 61 | model_output = self._forward_pass(model) |
|
0 commit comments