
Commit 77a6ae5

init docs (#45)
1 parent 6efac19 commit 77a6ae5

89 files changed (+3548, -625 lines)


.github/workflows/ruff.yml

Lines changed: 1 addition & 1 deletion
@@ -5,4 +5,4 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: chartboost/ruff-action@v1
+      - uses: astral-sh/ruff-action@v1

Makefile

Lines changed: 4 additions & 1 deletion
@@ -33,4 +33,7 @@ docs:
 
 .PHONY: serve-docs
 serve-docs: docs
-	$(poetry) python -m http.server -d docs/build/html 8333
+	$(poetry) python -m http.server -d docs/build/html 8333
+
+.PHONY: all
+all: lint

autointent/configs/__init__.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from .inference_cli import InferenceConfig
+from .inference_pipeline import InferencePipelineConfig
+from .node import InferenceNodeConfig, NodeOptimizerConfig
+from .optimization_cli import (
+    AugmentationConfig,
+    DataConfig,
+    EmbedderConfig,
+    LoggingConfig,
+    OptimizationConfig,
+    TaskConfig,
+    VectorIndexConfig,
+)
+from .pipeline_optimizer import PipelineOptimizerConfig
+
+__all__ = [
+    "AugmentationConfig",
+    "DataConfig",
+    "EmbedderConfig",
+    "InferenceConfig",
+    "InferenceNodeConfig",
+    "InferencePipelineConfig",
+    "LoggingConfig",
+    "NodeOptimizerConfig",
+    "OptimizationConfig",
+    "PipelineOptimizerConfig",
+    "TaskConfig",
+    "VectorIndexConfig",
+]
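
The new package __init__ above re-exports every config dataclass at package level. A minimal import sketch, assuming the layout in this commit; the training path is a hypothetical value, not part of the diff:

    from autointent.configs import DataConfig, OptimizationConfig

    # DataConfig.train_path defaults to omegaconf MISSING, so a real path must be supplied.
    data = DataConfig(train_path="data/train.json")  # hypothetical path
    print(data.force_multilabel)  # False by default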

autointent/configs/inference_cli.py

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Inference CLI configuration module."""
+
 from dataclasses import dataclass
 
 from hydra.core.config_store import ConfigStore
@@ -7,11 +9,18 @@
 
 @dataclass
 class InferenceConfig:
+    """Configuration for the inference process."""
+
     data_path: str
+    """Path to the file containing the data for prediction"""
     source_dir: str
+    """Path to the directory containing the inference config"""
     output_path: str
+    """Path to the file where the predictions will be saved"""
     log_level: LogLevel = LogLevel.ERROR
+    """Logging level"""
     with_metadata: bool = False
+    """Whether to save metadata along with the predictions"""
 
 
 cs = ConfigStore.instance()
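
InferenceConfig is a plain dataclass; the module also grabs Hydra's ConfigStore, though the registration call sits outside this hunk. A hedged sketch of filling it in, with all three paths as hypothetical placeholders:

    from autointent.configs import InferenceConfig

    cfg = InferenceConfig(
        data_path="data/utterances.json",  # hypothetical file with utterances to predict on
        source_dir="runs/my_run",          # hypothetical directory holding the inference config
        output_path="predictions.json",    # hypothetical output file
    )
    # Remaining fields keep their defaults: log_level=LogLevel.ERROR, with_metadata=False.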

autointent/configs/inference_pipeline.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Configuration for the inference pipeline."""
+
 from dataclasses import dataclass
 
 from omegaconf import MISSING
@@ -7,5 +9,8 @@
 
 @dataclass
 class InferencePipelineConfig:
+    """Configuration for the inference pipeline."""
+
     nodes: list[InferenceNodeConfig] = MISSING
+    """List of nodes in the inference pipeline"""
     _target_: str = "autointent.pipeline.InferencePipeline"

autointent/configs/name.py

Lines changed: 13 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Random name generator."""
+
 import random
 from datetime import datetime
 
@@ -340,12 +342,23 @@
 
 
 def generate_name() -> str:
+    """
+    Generate a random name for a run.
+
+    :return: Random name
+    """
     adjective = random.choice(adjectives)
     noun = random.choice(nouns)
     return f"{adjective}_{noun}"
 
 
 def get_run_name(run_name: str | None = None) -> str:
+    """
+    Get a run name.
+
+    :param run_name: Run name. If None, generate a random name
+    :return: Run name with a timestamp
+    """
     if run_name is None:
         run_name = generate_name()
     return f"{run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}"  # noqa: DTZ005
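
The commit moves the timestamp logic out of LoggingConfig and into get_run_name. A small illustration of the two helpers; the printed strings show only the format, since actual output is random and clock-dependent:

    from autointent.configs.name import generate_name, get_run_name

    print(generate_name())       # e.g. "brave_falcon" (random adjective_noun pair)
    print(get_run_name("demo"))  # e.g. "demo_01-31-2025_12-00-00" (%m-%d-%Y_%H-%M-%S suffix)
    print(get_run_name())        # run_name=None -> random name plus the same timestamp suffix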

autointent/configs/node.py

Lines changed: 19 additions & 4 deletions
@@ -1,21 +1,36 @@
+"""Configuration for the nodes."""
+
 from dataclasses import dataclass
 from typing import Any
 
 from omegaconf import MISSING
 
+from autointent.custom_types import NodeType, NodeTypeType
+
 
 @dataclass
 class InferenceNodeConfig:
-    node_type: str = MISSING
-    module_type: str = MISSING
+    """Configuration for the inference node."""
+
+    node_type: NodeTypeType = MISSING
+    """Type of the node. Should be one of the NODE_TYPES"""
+    module_type: str = MISSING  # TODO: add custom type
+    """Type of the module. Should be one of the Module"""
     module_config: dict[str, Any] = MISSING
+    """Configuration of the module"""
     load_path: str | None = None
+    """Path to the module dump. If None, the module will be trained from scratch"""
     _target_: str = "autointent.nodes.InferenceNode"
 
 
 @dataclass
 class NodeOptimizerConfig:
-    node_type: str = MISSING
+    """Configuration for the node optimizer."""
+
+    node_type: NodeType = MISSING
+    """Type of the node. Should be one of the NODE_TYPES"""
     search_space: list[dict[str, Any]] = MISSING
-    metric: str = MISSING
+    """Search space for the optimization"""
+    metric: str = MISSING  # TODO: add custom type
+    """Metric to optimize"""
     _target_: str = "autointent.nodes.NodeOptimizer"
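
The NodeType values and the available module types are defined elsewhere in the package and are not shown in this diff, so the values below are placeholders; this is only a sketch of how the two dataclasses are filled in:

    from autointent.configs import InferenceNodeConfig, NodeOptimizerConfig

    node = InferenceNodeConfig(
        node_type="scoring",      # placeholder; must be one of the NODE_TYPES
        module_type="knn",        # placeholder module name
        module_config={"k": 5},   # placeholder module parameters
        load_path=None,           # None -> the module will be trained from scratch
    )

    optimizer = NodeOptimizerConfig(
        node_type="scoring",                                     # placeholder
        search_space=[{"module_type": "knn", "k": [1, 5, 10]}],  # placeholder search space
        metric="scoring_accuracy",                               # placeholder metric name
    )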

autointent/configs/optimization_cli.py

Lines changed: 51 additions & 6 deletions
@@ -1,53 +1,70 @@
+"""Configuration for the optimization process."""
+
 from dataclasses import dataclass, field
-from datetime import datetime
 from pathlib import Path
 from typing import Any
 
 from hydra.core.config_store import ConfigStore
 from omegaconf import MISSING
 
-from .name import generate_name
+from .name import get_run_name
 
 
 @dataclass
 class DataConfig:
+    """Configuration for the data used in the optimization process."""
+
     train_path: str | Path = MISSING
+    """Path to the training data"""
     test_path: Path | None = None
+    """Path to the testing data. If None, no testing data will be used"""
     force_multilabel: bool = False
+    """Force multilabel classification even if the data is multiclass"""
 
 
 @dataclass
 class TaskConfig:
-    """TODO presets"""
+    """Configuration for the task to optimize."""
 
     search_space_path: Path | None = None
+    """Path to the search space configuration file. If None, the default search space will be used"""
 
 
 @dataclass
 class LoggingConfig:
+    """Configuration for the logging."""
+
     run_name: str | None = None
+    """Name of the run. If None, a random name will be generated"""
     dirpath: Path | None = None
+    """Path to the directory where the logs will be saved.
+    If None, the logs will be saved in the current working directory"""
     dump_dir: Path | None = None
+    """Path to the directory where the modules will be dumped. If None, the modules will not be dumped"""
     dump_modules: bool = False
+    """Whether to dump the modules or not"""
     clear_ram: bool = True
+    """Whether to clear the RAM after dumping the modules"""
 
     def __post_init__(self) -> None:
+        """Define the run name, directory path and dump directory."""
         self.define_run_name()
         self.define_dirpath()
         self.define_dump_dir()
 
     def define_run_name(self) -> None:
-        if self.run_name is None:
-            self.run_name = generate_name()
-        self.run_name = f"{self.run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}"  # noqa: DTZ005
+        """Define the run name. If None, a random name will be generated."""
+        self.run_name = get_run_name(self.run_name)
 
     def define_dirpath(self) -> None:
+        """Define the directory path. If None, the logs will be saved in the current working directory."""
         dirpath = Path.cwd() / "runs" if self.dirpath is None else self.dirpath
         if self.run_name is None:
             raise ValueError
         self.dirpath = dirpath / self.run_name
 
     def define_dump_dir(self) -> None:
+        """Define the dump directory. If None, the modules will not be dumped."""
         if self.dump_dir is None:
             if self.dirpath is None:
                 raise ValueError
@@ -56,32 +73,60 @@ def define_dump_dir(self) -> None:
 
 @dataclass
 class VectorIndexConfig:
+    """Configuration for the vector index."""
+
     db_dir: Path | None = None
+    """Path to the directory where the vector index database will be saved. If None, the database will not be saved"""
     device: str = "cpu"
+    """Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""
     save_db: bool = False
+    """Whether to save the vector index database or not"""
 
 
 @dataclass
 class AugmentationConfig:
+    """Configuration for the augmentation."""
+
     regex_sampling: int = 0
+    """Number of regex samples to generate"""
     multilabel_generation_config: str | None = None
+    """Path to the multilabel generation configuration file. If None, the default configuration will be used"""
 
 
 @dataclass
 class EmbedderConfig:
+    """
+    Configuration for the embedder.
+
+    The embedder is used to embed the data before training the model. These parameters
+    will be applied to the embedder used in the optimization process in vector db.
+    Only one model can be used globally.
+    """
+
     batch_size: int = 32
+    """Batch size for the embedder"""
     max_length: int | None = None
+    """Max length for the embedder. If None, the max length will be taken from model config"""
 
 
 @dataclass
 class OptimizationConfig:
+    """Configuration for the optimization process."""
+
     seed: int = 0
+    """Seed for the random number generator"""
     data: DataConfig = field(default_factory=DataConfig)
+    """Configuration for the data used in the optimization process"""
     task: TaskConfig = field(default_factory=TaskConfig)
+    """Configuration for the task to optimize"""
     logs: LoggingConfig = field(default_factory=LoggingConfig)
+    """Configuration for the logging"""
     vector_index: VectorIndexConfig = field(default_factory=VectorIndexConfig)
+    """Configuration for the vector index"""
     augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
+    """Configuration for the augmentation"""
     embedder: EmbedderConfig = field(default_factory=EmbedderConfig)
+    """Configuration for the embedder"""
 
     defaults: list[Any] = field(
         default_factory=lambda: [
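
To illustrate the LoggingConfig refactor above: __post_init__ now delegates run-name generation to get_run_name and then derives the log directory. A sketch of the resulting defaults, with the caveat that actual names depend on the random generator and the clock:

    from autointent.configs import LoggingConfig

    logs = LoggingConfig()  # run_name=None -> get_run_name() returns "<adjective>_<noun>_<timestamp>"
    print(logs.run_name)    # e.g. "brave_falcon_01-31-2025_12-00-00"
    print(logs.dirpath)     # Path.cwd() / "runs" / logs.run_name
    # dump_dir is filled in by define_dump_dir(); the tail of that method lies outside this hunk.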

autointent/configs/pipeline_optimizer.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""Pipeline configuration."""
+
 from dataclasses import dataclass
 
 from omegaconf import MISSING
@@ -7,5 +9,8 @@
 
 @dataclass
 class PipelineOptimizerConfig:
+    """Configuration for the pipeline optimizer."""
+
     nodes: list[NodeOptimizerConfig] = MISSING
+    """List of the nodes to optimize"""
     _target_: str = "autointent.pipeline.PipelineOptimizer"
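
Finally, a hedged sketch of composing the optimizer configs end to end; the node values are the same placeholders as above, and only the dataclass fields themselves come from this commit:

    from dataclasses import asdict

    from autointent.configs import NodeOptimizerConfig, PipelineOptimizerConfig

    pipeline = PipelineOptimizerConfig(
        nodes=[
            NodeOptimizerConfig(
                node_type="scoring",                    # placeholder
                search_space=[{"module_type": "knn"}],  # placeholder
                metric="scoring_accuracy",              # placeholder
            )
        ]
    )
    print(asdict(pipeline)["_target_"])  # "autointent.pipeline.PipelineOptimizer"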
