
Commit b6669ec

rename some variables

1 parent 5072810 · commit b6669ec

7 files changed (+24 −24 lines)

autointent/_pipeline/_pipeline.py

Lines changed: 5 additions & 5 deletions
@@ -10,7 +10,7 @@

 from autointent import Context, Dataset
 from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
-from autointent.custom_types import ListOfGenericLabels, NodeType, TuningType, ValidationScheme
+from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
 from autointent.metrics import PREDICTION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.utils import load_default_search_space, load_search_space

@@ -92,7 +92,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline":
         """
         return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed)

-    def _fit(self, context: Context, tuning: TuningType = "brute") -> None:
+    def _fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         """
         Optimize the pipeline.

@@ -107,7 +107,7 @@ def _fit(self, context: Context, tuning: TuningType = "brute") -> None:
         for node_type in NodeType:
             node_optimizer = self.nodes.get(node_type, None)
             if node_optimizer is not None:
-                node_optimizer.fit(context, tuning)  # type: ignore[union-attr]
+                node_optimizer.fit(context, sampler)  # type: ignore[union-attr]
         if not context.vector_index_config.save_db:
             self._logger.info("removing vector database from file system...")
             # TODO clear cache from appdirs

@@ -127,7 +127,7 @@ def fit(
         scheme: ValidationScheme = "ho",
         n_folds: int = 3,
         refit_after: bool = False,
-        tuning: TuningType = "brute",
+        sampler: SamplerType = "brute",
     ) -> Context:
         """
         Optimize the pipeline from dataset.

@@ -145,7 +145,7 @@ def fit(
         context.configure_vector_index(self.vector_index_config, self.embedder_config)
         context.configure_cross_encoder(self.cross_encoder_config)
         self.validate_modules(dataset)
-        self._fit(context, tuning)
+        self._fit(context, sampler)

         if context.is_ram_to_clear():
             nodes_configs = context.optimization_info.get_inference_nodes_config()
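For reference, this rename changes the keyword on the public fit call from tuning to sampler. A minimal usage sketch, assuming Pipeline and Dataset are importable from the top-level autointent package (as in the tests further down) and using a hypothetical dataset path:

# Sketch only, not part of this commit. The dataset path is hypothetical.
from autointent import Dataset, Pipeline  # assumed top-level exports

dataset = Dataset.from_json("intents.json")  # hypothetical file
pipeline = Pipeline.default_optimizer(multilabel=False, seed=42)

# `sampler` replaces the former `tuning` keyword; accepted values are
# "brute", "tpe", and "random" (the SamplerType literal).
context = pipeline.fit(dataset, scheme="ho", n_folds=3, refit_after=False, sampler="tpe")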

autointent/configs/_optimization.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@

 from pydantic import BaseModel, Field

-from autointent.custom_types import TuningType, ValidationScheme
+from autointent.custom_types import SamplerType, ValidationScheme

 from ._name import get_run_name

@@ -25,7 +25,7 @@ class TaskConfig(BaseModel):

     search_space_path: Path | None = None
     """Path to the search space configuration file. If None, the default search space will be used"""
-    sampler: TuningType = "brute"
+    sampler: SamplerType = "brute"


 class LoggingConfig(BaseModel):
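Because TaskConfig is a pydantic model, the retyped field is validated against the new literal. A small sketch under assumptions (the import path of TaskConfig is assumed; only the two fields shown in the diff are used):

# Sketch only; assumes TaskConfig is exported from autointent.configs.
from pydantic import ValidationError

from autointent.configs import TaskConfig  # assumed import path

cfg = TaskConfig(sampler="tpe")  # any of "brute" | "tpe" | "random" passes validation
print(cfg.sampler, cfg.search_space_path)

try:
    TaskConfig(sampler="bayes")  # the old spelling is no longer a member of the literal
except ValidationError as err:
    print(err)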

autointent/custom_types.py

Lines changed: 1 addition & 1 deletion
@@ -71,5 +71,5 @@ class Split:
     INTENTS = "intents"


-TuningType = Literal["brute", "bayes", "random"]
+SamplerType = Literal["brute", "tpe", "random"]
 ValidationScheme = Literal["ho", "cv"]

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 10 additions & 10 deletions
@@ -13,7 +13,7 @@

 from autointent import Dataset
 from autointent.context import Context
-from autointent.custom_types import NodeType, TuningType
+from autointent.custom_types import NodeType, SamplerType
 from autointent.nodes._nodes_info import NODES_INFO


@@ -59,7 +59,7 @@ def __init__(
         self.modules_search_spaces = search_space
         self._logger = logging.getLogger(__name__)  # TODO solve duplicate logging messages problem

-    def fit(self, context: Context, tuning: TuningType = "brute") -> None:
+    def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         """
         Fit the node optimizer.

@@ -73,19 +73,19 @@ def fit(self, context: Context, tuning: TuningType = "brute") -> None:
             n_trials = None
             if "n_trials" in search_space:
                 n_trials = search_space.pop("n_trials")
-            if tuning == "bayes":
-                sampler = optuna.samplers.TPESampler(seed=context.seed)
+            if sampler == "tpe":
+                sampler_instance = optuna.samplers.TPESampler(seed=context.seed)
                 n_trials = n_trials or 10
-            elif tuning == "brute":
-                sampler = optuna.samplers.BruteForceSampler(seed=context.seed)  # type: ignore[assignment]
+            elif sampler == "brute":
+                sampler_instance = optuna.samplers.BruteForceSampler(seed=context.seed)  # type: ignore[assignment]
                 n_trials = None
-            elif tuning == "random":
-                sampler = optuna.samplers.RandomSampler(seed=context.seed)  # type: ignore[assignment]
+            elif sampler == "random":
+                sampler_instance = optuna.samplers.RandomSampler(seed=context.seed)  # type: ignore[assignment]
                 n_trials = n_trials or 10
             else:
-                msg = f"Unexpected sampler: {tuning}"
+                msg = f"Unexpected sampler: {sampler}"
                 raise ValueError(msg)
-            study = optuna.create_study(direction="maximize", sampler=sampler)
+            study = optuna.create_study(direction="maximize", sampler=sampler_instance)
             optuna.logging.set_verbosity(optuna.logging.WARNING)
             obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
             study.optimize(obj, n_trials=n_trials)
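The substance of this file's change: the option string now names the Optuna sampler directly ("tpe" rather than "bayes"), and the local holding the sampler object is renamed to sampler_instance so it no longer shadows the argument. A self-contained sketch of the same dispatch, using only Optuna calls that appear in the diff; the make_sampler helper is illustrative, not project code:

# Standalone sketch of the sampler dispatch above (illustrative only).
from typing import Literal

import optuna

SamplerType = Literal["brute", "tpe", "random"]


def make_sampler(sampler: SamplerType, seed: int) -> optuna.samplers.BaseSampler:
    """Map a sampler name to an Optuna sampler instance, mirroring the diff above."""
    if sampler == "tpe":
        return optuna.samplers.TPESampler(seed=seed)
    if sampler == "brute":
        return optuna.samplers.BruteForceSampler(seed=seed)
    if sampler == "random":
        return optuna.samplers.RandomSampler(seed=seed)
    msg = f"Unexpected sampler: {sampler}"
    raise ValueError(msg)


# Example: maximize a toy objective with the TPE sampler.
study = optuna.create_study(direction="maximize", sampler=make_sampler("tpe", seed=42))
study.optimize(lambda trial: -(trial.suggest_float("x", -1.0, 1.0) ** 2), n_trials=10)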

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def dataset_unsplitted():
     return Dataset.from_json(path)


-TaskType = Literal["multiclass", "multilabel", "description", "bayes"]
+TaskType = Literal["multiclass", "multilabel", "description", "optuna"]


 def get_search_space_path(task_type: TaskType):

tests/pipeline/test_optimization.py

Lines changed: 5 additions & 5 deletions
@@ -12,20 +12,20 @@


 @pytest.mark.parametrize(
-    "tuning",
-    ["bayes", "random"],
+    "sampler",
+    ["tpe", "random"],
 )
-def test_bayes(dataset, tuning):
+def test_bayes(dataset, sampler):
     project_dir = setup_environment()
-    search_space = get_search_space("bayes")
+    search_space = get_search_space("optuna")

     pipeline_optimizer = Pipeline.from_search_space(search_space)

     pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
     pipeline_optimizer.set_config(VectorIndexConfig())
     pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))

-    pipeline_optimizer.fit(dataset, scheme="ho", refit_after=False, tuning=tuning)
+    pipeline_optimizer.fit(dataset, scheme="ho", refit_after=False, sampler=sampler)


 @pytest.mark.parametrize(
