
Commit 86384cf

Authored by Samoed, github-actions[bot], and voorhs
add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
1 parent 2a6a952 commit 86384cf

File tree: 10 files changed (+415, -43 lines)


autointent/modules/__init__.py

Lines changed: 22 additions & 1 deletion
@@ -54,4 +54,25 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
 )
 
 
-__all__ = []  # type: ignore[var-annotated]
+__all__ = [
+    "AdaptiveDecision",
+    "ArgmaxDecision",
+    "BaseDecision",
+    "BaseEmbedding",
+    "BaseModule",
+    "BaseRegex",
+    "BaseScorer",
+    "DNNCScorer",
+    "DescriptionScorer",
+    "JinoosDecision",
+    "KNNScorer",
+    "LinearScorer",
+    "LogregAimedEmbedding",
+    "MLKnnScorer",
+    "RerankScorer",
+    "RetrievalAimedEmbedding",
+    "SimpleRegex",
+    "SklearnScorer",
+    "ThresholdDecision",
+    "TunableDecision",
+]
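
With __all__ populated, the module classes become the package's documented public interface. A minimal usage sketch, assuming the listed classes are re-exported by autointent.modules (only the export list is changed in this diff):

# Minimal sketch: the names listed in __all__ are assumed to resolve to the
# corresponding module classes when imported from the package.
from autointent.modules import KNNScorer, ThresholdDecision

print(KNNScorer.__name__)          # "KNNScorer"
print(ThresholdDecision.__name__)  # "ThresholdDecision"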

autointent/modules/scoring/_bert.py

Lines changed: 13 additions & 11 deletions
@@ -31,14 +31,14 @@ class BertScorer(BaseScorer):
 
     def __init__(
         self,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
     ) -> None:
-        self.model_config = HFModelConfig.from_search_config(model_config)
+        self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
@@ -49,19 +49,19 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
     ) -> "BertScorer":
-        if model_config is None:
-            model_config = context.resolve_embedder()
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()
 
         report_to = context.logging_config.report_to
 
         return cls(
-            model_config=model_config,
+            classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
@@ -70,7 +70,7 @@ def from_context(
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
-        return self.model_config.model_dump()
+        return self.classification_model_config.model_dump()
 
     def fit(
         self,
@@ -81,7 +81,7 @@ def fit(
         self.clear_cache()
         self._validate_task(labels)
 
-        model_name = self.model_config.model_name
+        model_name = self.classification_model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         label2id = {i: i for i in range(self._n_classes)}
@@ -95,11 +95,11 @@ def fit(
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
         )
 
-        use_cpu = self.model_config.device == "cpu"
+        use_cpu = self.classification_model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+                examples["text"], return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -148,7 +148,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(
+                batch, return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
+            )
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
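
Callers that construct BertScorer directly need to switch to the renamed keyword. A hedged sketch of the updated call; the import path and model name below are illustrative assumptions, not confirmed by this diff:

# Hypothetical caller update for the rename; the import path and model name
# are placeholders chosen for illustration only.
from autointent.modules.scoring._bert import BertScorer

scorer = BertScorer(
    classification_model_config="bert-base-uncased",  # was: model_config=...
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    seed=0,
)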

autointent/modules/scoring/_linear.py

Lines changed: 4 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import numpy as np
 import numpy.typing as npt
+from pydantic import PositiveInt
 from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
 from sklearn.multioutput import MultiOutputClassifier
 
@@ -22,7 +23,6 @@ class LinearScorer(BaseScorer):
     Args:
         embedder_config: Config of the embedder model
         cv: Number of cross-validation folds, defaults to 3
-        n_jobs: Number of parallel jobs for cross-validation, defaults to None
         seed: Random seed for reproducibility, defaults to 0
 
     Example:
@@ -72,18 +72,21 @@ def __init__(
     def from_context(
         cls,
         context: Context,
+        cv: PositiveInt = 3,
         embedder_config: EmbedderConfig | str | None = None,
     ) -> "LinearScorer":
         """Create a LinearScorer instance using a Context object.
 
         Args:
             context: Context containing configurations and utilities
+            cv: Number of cross-validation folds, defaults to 3
             embedder_config: Config of the embedder, or None to use the best embedder
         """
         if embedder_config is None:
             embedder_config = context.resolve_embedder()
 
         return cls(
+            cv=cv,
             embedder_config=embedder_config,
         )
 
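
Since from_context now forwards cv to the constructor, the fold count can also be set when building the scorer directly. A minimal sketch, assuming the constructor accepts the cv and embedder_config keywords shown in the return cls(...) call above; the embedder name is a placeholder:

# Minimal sketch: keyword names follow the cls(cv=..., embedder_config=...) call
# in from_context; the embedder model name is an illustrative placeholder.
from autointent.modules import LinearScorer

scorer = LinearScorer(
    embedder_config="sentence-transformers/all-MiniLM-L6-v2",
    cv=5,
)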

autointent/nodes/_node_optimizer.py

Lines changed: 3 additions & 21 deletions
@@ -11,33 +11,14 @@
 import optuna
 import torch
 from optuna.trial import Trial
-from pydantic import BaseModel, Field
 from typing_extensions import assert_never
 
 from autointent import Dataset
 from autointent.context import Context
 from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
 from autointent.nodes.emissions_tracker import EmissionsTracker
 from autointent.nodes.info import NODES_INFO
-
-
-class ParamSpaceInt(BaseModel):
-    """Integer parameter search space configuration."""
-
-    low: int = Field(..., description="Lower boundary of the search space.")
-    high: int = Field(..., description="Upper boundary of the search space.")
-    step: int = Field(1, description="Step size for the search space.")
-    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
-
-
-class ParamSpaceFloat(BaseModel):
-    """Float parameter search space configuration."""
-
-    low: float = Field(..., description="Lower boundary of the search space.")
-    high: float = Field(..., description="Upper boundary of the search space.")
-    step: float | None = Field(None, description="Step size for the search space (if applicable).")
-    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
-
+from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, SearchSpaceConfig
 
 logger = logging.getLogger(__name__)
 
@@ -277,7 +258,8 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
 
     def validate_search_space(self, search_space: list[dict[str, Any]]) -> None:
         """Check if search space is configured correctly."""
-        for module_search_space in search_space:
+        validated_search_space = SearchSpaceConfig(search_space).model_dump()
+        for module_search_space in validated_search_space:
             module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))
 
             for params_combination in it.product(*module_search_space_no_optuna.values()):
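
ParamSpaceInt and ParamSpaceFloat move verbatim into autointent.schemas.node_validation, and validate_search_space now passes the raw search space through SearchSpaceConfig before iterating. SearchSpaceConfig's definition is not shown in this commit, so the following is only a sketch of how the relocated validators might be exercised; the search-space dictionary keys are illustrative assumptions:

# Sketch only: ParamSpaceInt mirrors the model removed from _node_optimizer.py.
# SearchSpaceConfig is assumed to accept the search-space list positionally (as
# called in validate_search_space); the dictionary keys below are illustrative.
from autointent.schemas.node_validation import ParamSpaceInt, SearchSpaceConfig

k_space = ParamSpaceInt(low=1, high=16, step=1, log=False)
print(k_space.model_dump())  # {'low': 1, 'high': 16, 'step': 1, 'log': False}

search_space = [{"node_type": "scoring", "search_space": [{"module_name": "knn", "k": [5, 10]}]}]
validated = SearchSpaceConfig(search_space).model_dump()  # raises on malformed entries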
