1+ """Configuration for the optimization process."""
2+
13from dataclasses import dataclass , field
2- from datetime import datetime
34from pathlib import Path
45from typing import Any
56
67from hydra .core .config_store import ConfigStore
78from omegaconf import MISSING
89
9- from .name import generate_name
10+ from .name import get_run_name
1011
1112
@dataclass
class DataConfig:
    """Settings describing the datasets consumed by an optimization run."""

    train_path: str | Path = MISSING
    """Location of the training data. Declared as MISSING, so no default is provided."""
    test_path: Path | None = None
    """Location of the testing data. Leave as None to run without testing data."""
    force_multilabel: bool = False
    """Treat the problem as multilabel classification even when the data is multiclass."""
1723
1824
1925@dataclass
2026class TaskConfig :
21- """TODO presets """
27+ """Configuration for the task to optimize. """
2228
2329 search_space_path : Path | None = None
30+ """Path to the search space configuration file. If None, the default search space will be used"""
2431
2532
2633@dataclass
2734class LoggingConfig :
35+ """Configuration for the logging."""
36+
2837 run_name : str | None = None
38+ """Name of the run. If None, a random name will be generated"""
2939 dirpath : Path | None = None
40+ """Path to the directory where the logs will be saved.
41+ If None, the logs will be saved in the current working directory"""
3042 dump_dir : Path | None = None
43+ """Path to the directory where the modules will be dumped. If None, the modules will not be dumped"""
3144 dump_modules : bool = False
45+ """Whether to dump the modules or not"""
3246 clear_ram : bool = True
47+ """Whether to clear the RAM after dumping the modules"""
3348
3449 def __post_init__ (self ) -> None :
50+ """Define the run name, directory path and dump directory."""
3551 self .define_run_name ()
3652 self .define_dirpath ()
3753 self .define_dump_dir ()
3854
def define_run_name(self) -> None:
    """
    Replace ``run_name`` with the value returned by ``get_run_name``.

    NOTE(review): ``get_run_name`` is imported from ``.name``; presumably it generates
    a name when given None -- confirm against that helper.
    """
    resolved_name = get_run_name(self.run_name)
    self.run_name = resolved_name
4358
def define_dirpath(self) -> None:
    """
    Resolve the directory that will hold the logs of this run.

    The result is ``<base>/<run_name>``, where ``<base>`` is the configured ``dirpath``
    or, when that is None, a ``runs`` folder inside the current working directory.

    :raises ValueError: if ``run_name`` is still None when this method is called
    """
    if self.run_name is None:
        # The run name is the last path component, so it has to be resolved first.
        msg = "run_name must be set before the log directory can be resolved"
        raise ValueError(msg)
    base_dir = Path.cwd() / "runs" if self.dirpath is None else self.dirpath
    self.dirpath = base_dir / self.run_name
4965
5066 def define_dump_dir (self ) -> None :
67+ """Define the dump directory. If None, the modules will not be dumped."""
5168 if self .dump_dir is None :
5269 if self .dirpath is None :
5370 raise ValueError
@@ -56,32 +73,60 @@ def define_dump_dir(self) -> None:
5673
5774@dataclass
5875class VectorIndexConfig :
76+ """Configuration for the vector index."""
77+
5978 db_dir : Path | None = None
79+ """Path to the directory where the vector index database will be saved. If None, the database will not be saved"""
6080 device : str = "cpu"
81+ """Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""
6182 save_db : bool = False
83+ """Whether to save the vector index database or not"""
6284
6385
6486@dataclass
6587class AugmentationConfig :
88+ """Configuration for the augmentation."""
89+
6690 regex_sampling : int = 0
91+ """Number of regex samples to generate"""
6792 multilabel_generation_config : str | None = None
93+ """Path to the multilabel generation configuration file. If None, the default configuration will be used"""
6894
6995
7096@dataclass
7197class EmbedderConfig :
98+ """
99+ Configuration for the embedder.
100+
101+ The embedder is used to embed the data before training the model. These parameters
102+ will be applied to the embedder used in the optimization process in vector db.
103+ Only one model can be used globally.
104+ """
105+
72106 batch_size : int = 32
107+ """Batch size for the embedder"""
73108 max_length : int | None = None
109+ """Max length for the embedder. If None, the max length will be taken from model config"""
74110
75111
76112@dataclass
77113class OptimizationConfig :
114+ """Configuration for the optimization process."""
115+
78116 seed : int = 0
117+ """Seed for the random number generator"""
79118 data : DataConfig = field (default_factory = DataConfig )
119+ """Configuration for the data used in the optimization process"""
80120 task : TaskConfig = field (default_factory = TaskConfig )
121+ """Configuration for the task to optimize"""
81122 logs : LoggingConfig = field (default_factory = LoggingConfig )
123+ """Configuration for the logging"""
82124 vector_index : VectorIndexConfig = field (default_factory = VectorIndexConfig )
125+ """Configuration for the vector index"""
83126 augmentation : AugmentationConfig = field (default_factory = AugmentationConfig )
127+ """Configuration for the augmentation"""
84128 embedder : EmbedderConfig = field (default_factory = EmbedderConfig )
129+ """Configuration for the embedder"""
85130
86131 defaults : list [Any ] = field (
87132 default_factory = lambda : [
0 commit comments