Skip to content

Commit f01d2c9

Browse files
Feat/Efficient hpo (#227)
* implement new logic * remove bruteforce sampling support * Update optimizer_config.schema.json * remove brute * Update optimizer_config.schema.json * upd contributing md * fix presets * fix optuna search space * upd callback test * implement separate config for hpo * bug fix * Update optimizer_config.schema.json * bug fix * upd test * add utf-8 everywhere * add `utf-8` --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 95be22a commit f01d2c9

30 files changed

+1079
-345
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Note: If mypy shows different errors locally compared to github actions, you sho
5959
```bash
6060
make update
6161
```
62+
But it still doesn't guarantee that the local type checker will give the same errors as CI. This is because CI is configured to check on Python 3.10 and your local python version is probably the latest one.
6263

6364
## Building Documentation
6465

autointent/_dataset/_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def to_json(self, filepath: str | Path) -> None:
133133
path = Path(filepath)
134134
if not path.parent.exists():
135135
path.parent.mkdir(parents=True)
136-
with path.open("w") as file:
136+
with path.open("w", encoding="utf-8") as file:
137137
json.dump(self.to_dict(), file, indent=4, ensure_ascii=False)
138138

139139
def push_to_hub(self, repo_name: str, private: bool = False) -> None:

autointent/_dataset/_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,5 @@ def _read(self, filepath: str | Path) -> DatasetReader:
9595
Returns:
9696
DatasetReader: A validated dataset representation.
9797
"""
98-
with Path(filepath).open() as file:
98+
with Path(filepath).open(encoding="utf-8") as file:
9999
return DatasetReader.model_validate(json.load(file))

autointent/_dump_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
151151
msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
152152
logger.error(msg)
153153

154-
with (path / Dumper.simple_attrs).open("w") as file:
154+
with (path / Dumper.simple_attrs).open("w", encoding="utf-8") as file:
155155
json.dump(simple_attrs, file, ensure_ascii=False, indent=4)
156156

157157
np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
@@ -179,7 +179,7 @@ def load( # noqa: C901, PLR0912, PLR0915
179179
if child.name == Dumper.tags:
180180
tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
181181
elif child.name == Dumper.simple_attrs:
182-
with child.open() as file:
182+
with child.open(encoding="utf-8") as file:
183183
simple_attrs = json.load(file)
184184
elif child.name == Dumper.arrays:
185185
arrays = dict(np.load(child))

autointent/_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
164164
path: Path to the directory where the model is stored.
165165
override_config: one can override presaved settings
166166
"""
167-
with (Path(path) / cls._metadata_dict_name).open() as file:
167+
with (Path(path) / cls._metadata_dict_name).open(encoding="utf-8") as file:
168168
metadata: EmbedderDumpMetadata = json.load(file)
169169

170170
if override_config is not None:

autointent/_logging/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def setup_logging(level: LogLevel | str, log_filename: Path | str | None = None)
2020
log_filename: specify location of logfile, omit extension as suffix ``.log.jsonl`` will be appended.
2121
"""
2222
config_file = ires.files("autointent._logging").joinpath("config.yaml")
23-
with config_file.open() as f_in:
23+
with config_file.open(encoding="utf-8") as f_in:
2424
config = yaml.safe_load(f_in)
2525

2626
level = LogLevel(level)

autointent/_optimization_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from pydantic import BaseModel, PositiveInt
44

5-
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
6-
from .custom_types import SamplerType
5+
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, HPOConfig, LoggingConfig
76

87

98
class OptimizationConfig(BaseModel):
@@ -27,7 +26,6 @@ class OptimizationConfig(BaseModel):
2726

2827
transformer_config: HFModelConfig = HFModelConfig()
2928

30-
sampler: SamplerType = "brute"
31-
"""See tutorial on optuna and presets."""
29+
hpo_config: HPOConfig = HPOConfig()
3230

3331
seed: PositiveInt = 42

autointent/_pipeline/_pipeline.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, get_args
6+
from typing import TYPE_CHECKING, Any
77

88
import numpy as np
99
import yaml
@@ -15,13 +15,13 @@
1515
DataConfig,
1616
EmbedderConfig,
1717
HFModelConfig,
18+
HPOConfig,
1819
InferenceNodeConfig,
1920
LoggingConfig,
2021
)
2122
from autointent.custom_types import (
2223
ListOfGenericLabels,
2324
NodeType,
24-
SamplerType,
2525
SearchSpacePreset,
2626
SearchSpaceValidationMode,
2727
)
@@ -44,7 +44,6 @@ class Pipeline:
4444
def __init__(
4545
self,
4646
nodes: list[NodeOptimizer] | list[InferenceNode],
47-
sampler: SamplerType = "brute",
4847
seed: int | None = 42,
4948
) -> None:
5049
"""Initialize the pipeline optimizer.
@@ -57,23 +56,19 @@ def __init__(
5756
self._logger = logging.getLogger(__name__)
5857
self.nodes = {node.node_type: node for node in nodes}
5958
self._seed = seed
60-
if sampler not in get_args(SamplerType):
61-
msg = f"Sampler should be one of {get_args(SamplerType)}"
62-
raise ValueError(msg)
63-
64-
self._sampler = sampler
6559

6660
if isinstance(nodes[0], NodeOptimizer):
6761
self.logging_config = LoggingConfig()
6862
self.embedder_config = EmbedderConfig()
6963
self.cross_encoder_config = CrossEncoderConfig()
7064
self.data_config = DataConfig()
7165
self.transformer_config = HFModelConfig()
66+
self.hpo_config = HPOConfig()
7267
elif not isinstance(nodes[0], InferenceNode):
7368
assert_never(nodes)
7469

7570
def set_config(
76-
self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
71+
self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig | HPOConfig
7772
) -> None:
7873
"""Set the configuration for the pipeline.
7974
@@ -90,6 +85,8 @@ def set_config(
9085
self.data_config = config
9186
elif isinstance(config, HFModelConfig):
9287
self.transformer_config = config
88+
elif isinstance(config, HPOConfig):
89+
self.hpo_config = config
9390
else:
9491
assert_never(config)
9592

@@ -126,23 +123,23 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
126123
if isinstance(config, dict):
127124
dict_params = config
128125
else:
129-
with Path(config).open() as file:
126+
with Path(config).open(encoding="utf-8") as file:
130127
dict_params = yaml.safe_load(file)
131128
optimization_config = OptimizationConfig(**dict_params)
132129

133130
pipeline = cls(
134131
[NodeOptimizer(**node) for node in optimization_config.search_space],
135-
optimization_config.sampler,
136132
optimization_config.seed,
137133
)
138134
pipeline.set_config(optimization_config.logging_config)
139135
pipeline.set_config(optimization_config.data_config)
140136
pipeline.set_config(optimization_config.embedder_config)
141137
pipeline.set_config(optimization_config.cross_encoder_config)
142138
pipeline.set_config(optimization_config.transformer_config)
139+
pipeline.set_config(optimization_config.hpo_config)
143140
return pipeline
144141

145-
def _fit(self, context: Context, sampler: SamplerType) -> None:
142+
def _fit(self, context: Context) -> None:
146143
"""Optimize the pipeline.
147144
148145
Args:
@@ -167,7 +164,7 @@ def _fit(self, context: Context, sampler: SamplerType) -> None:
167164
for node_type in NodeType:
168165
node_optimizer = self.nodes.get(node_type, None)
169166
if node_optimizer is not None:
170-
node_optimizer.fit(context, sampler) # type: ignore[union-attr]
167+
node_optimizer.fit(context) # type: ignore[union-attr]
171168
self.context.callback_handler.end_run()
172169

173170
def _is_inference(self) -> bool:
@@ -182,7 +179,6 @@ def fit(
182179
self,
183180
dataset: Dataset,
184181
refit_after: bool = False,
185-
sampler: SamplerType | None = None,
186182
incompatible_search_space: SearchSpaceValidationMode = "filter",
187183
) -> Context:
188184
"""Optimize the pipeline from dataset.
@@ -206,6 +202,7 @@ def fit(
206202
context.configure_transformer(self.embedder_config)
207203
context.configure_transformer(self.cross_encoder_config)
208204
context.configure_transformer(self.transformer_config)
205+
context.configure_hpo(self.hpo_config)
209206

210207
self.validate_modules(dataset, mode=incompatible_search_space)
211208

@@ -221,10 +218,7 @@ def fit(
221218
"Change settings in LoggerConfig to obtain different behavior."
222219
)
223220

224-
if sampler is None:
225-
sampler = self._sampler
226-
227-
self._fit(context, sampler)
221+
self._fit(context)
228222

229223
if context.logging_config.clear_ram and context.logging_config.dump_modules:
230224
nodes_configs = context.optimization_info.get_inference_nodes_config()
@@ -336,7 +330,7 @@ def load(
336330
embedder_config: one can override presaved settings
337331
cross_encoder_config: one can override presaved settings
338332
"""
339-
with (Path(path) / "inference_config.yaml").open() as file:
333+
with (Path(path) / "inference_config.yaml").open(encoding="utf-8") as file:
340334
inference_nodes_configs: list[dict[str, Any]] = yaml.safe_load(file)
341335

342336
inference_config = [

autointent/_presets/heavy.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,16 @@ search_space:
88
low: 1
99
high: 20
1010
weights: [uniform, distance, closest]
11-
n_trials: 10
1211
- module_name: linear
1312
- module_name: mlknn
1413
k:
1514
low: 1
1615
high: 20
17-
n_trials: 10
1816
- module_name: description
1917
temperature:
2018
low: 0.01
2119
high: 10
2220
log: true
23-
n_trials: 10
2421
- module_name: rerank
2522
k:
2623
low: 10
@@ -29,17 +26,18 @@ search_space:
2926
low: 1
3027
high: 10
3128
weights: [uniform, distance, closest]
32-
n_trials: 15
3329
- node_type: decision
3430
target_metric: decision_accuracy
3531
search_space:
3632
- module_name: threshold
3733
thresh:
3834
low: 0.1
3935
high: 0.9
40-
n_trials: 10
4136
- module_name: argmax
4237
- module_name: jinoos
4338
- module_name: tunable
4439
- module_name: adaptive
45-
sampler: tpe
40+
hpo_config:
41+
sampler: tpe
42+
n_trials: 128 # don't know yet if it's good
43+
n_startup_trials: 32

autointent/_presets/heavy_extra.yaml

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments (0)