|
1 | 1 | """Configuration for the nodes.""" |
2 | 2 |
|
3 | | -from dataclasses import asdict, dataclass |
4 | 3 | from typing import Any |
5 | 4 |
|
6 | 5 | from autointent.custom_types import NodeType |
7 | 6 |
|
8 | 7 | from ._transformers import CrossEncoderConfig, EmbedderConfig |
9 | 8 |
|
10 | 9 |
|
11 | | -@dataclass |
12 | 10 | class InferenceNodeConfig: |
13 | 11 | """Configuration for the inference node.""" |
14 | 12 |
|
15 | | - node_type: NodeType |
16 | | - """Type of the node.""" |
17 | | - module_name: str |
18 | | - """Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`.""" |
19 | | - module_config: dict[str, Any] |
20 | | - """Hyperparameters of underlying module.""" |
21 | | - load_path: str |
22 | | - """Path to the module dump.""" |
23 | | - embedder_config: EmbedderConfig | None = None |
24 | | - """One can override presaved embedder config while loading from file system.""" |
25 | | - cross_encoder_config: CrossEncoderConfig | None = None |
26 | | - """One can override presaved cross encoder config while loading from file system.""" |
| 13 | + def __init__( |
| 14 | + self, |
| 15 | + node_type: NodeType, |
| 16 | + module_name: str, |
| 17 | + module_config: dict[str, Any], |
| 18 | + load_path: str, |
| 19 | + embedder_config: EmbedderConfig | None = None, |
| 20 | + cross_encoder_config: CrossEncoderConfig | None = None, |
| 21 | + ) -> None: |
| 22 | + """Initialize the InferenceNodeConfig. |
| 23 | +
|
| 24 | + Args: |
| 25 | + node_type: Type of the node. |
| 26 | + module_name: Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`. |
| 27 | + module_config: Hyperparameters of underlying module. |
| 28 | + load_path: Path to the module dump. |
| 29 | + embedder_config: One can override presaved embedder config while loading from file system. |
| 30 | + cross_encoder_config: One can override presaved cross encoder config while loading from file system. |
| 31 | + """ |
| 32 | + self.node_type = node_type |
| 33 | + self.module_name = module_name |
| 34 | + self.module_config = module_config |
| 35 | + self.load_path = load_path |
| 36 | + |
| 37 | + if embedder_config is not None: |
| 38 | + self.embedder_config = embedder_config |
| 39 | + if cross_encoder_config is not None: |
| 40 | + self.cross_encoder_config = cross_encoder_config |
27 | 41 |
|
28 | 42 | def asdict(self) -> dict[str, Any]: |
29 | | - """Convert config to dict format.""" |
30 | | - res = asdict(self) |
31 | | - if self.embedder_config is not None: |
32 | | - res["embedder_config"] = self.embedder_config.model_dump() |
33 | | - else: |
34 | | - res.pop("embedder_config") |
35 | | - if self.cross_encoder_config is not None: |
36 | | - res["cross_encoder_config"] = self.cross_encoder_config.model_dump() |
37 | | - else: |
38 | | - res.pop("cross_encoder_config") |
39 | | - return res |
| 43 | + """Convert the InferenceNodeConfig to a dictionary. |
| 44 | +
|
| 45 | + Returns: |
| 46 | + A dictionary representation of the InferenceNodeConfig. |
| 47 | + """ |
| 48 | + result = { |
| 49 | + "node_type": self.node_type, |
| 50 | + "module_name": self.module_name, |
| 51 | + "module_config": self.module_config, |
| 52 | + "load_path": self.load_path, |
| 53 | + } |
| 54 | + |
| 55 | + if hasattr(self, "embedder_config"): |
| 56 | + result["embedder_config"] = self.embedder_config.model_dump() |
| 57 | + if hasattr(self, "cross_encoder_config"): |
| 58 | + result["cross_encoder_config"] = self.cross_encoder_config.model_dump() |
| 59 | + |
| 60 | + return result |
0 commit comments