
Commit abe8c2f

voorhs and Samoed authored
Refactor/tunable threshold (#135)
* return tunable decision module to test search spaces
* fix constructor
* fix typing
* bug fix
* upd test
* update schema on commit
* fix doctest
* fix schema
* fix doctest
* Update check-schema.yaml

---------

Co-authored-by: Roman Solomatin <[email protected]>
1 parent: 063391a · commit: abe8c2f

File tree

10 files changed, +81 −33 lines


.github/workflows/check-schema.yaml

Lines changed: 8 additions & 5 deletions
@@ -30,20 +30,23 @@ jobs:
           python -m scripts.generate_json_schema_config

       - name: Check for changes in JSON Schema
+        id: check_changes
         run: |
-          if ! git diff docs/optimizer_config.schema.json; then
-            echo "Error: docs/optimizer_config.schema.json has been modified after running the generator script."
+          if git diff --quiet docs/optimizer_config.schema.json; then
+            echo "No changes detected."
+            echo "changed=false" >> $GITHUB_ENV
           else
-            echo "No changes detected in docs/optimizer_config.schema.json."
-            exit 0
+            echo "Changes detected."
+            echo "changed=true" >> $GITHUB_ENV
           fi

       - name: Commit and push changes
+        if: env.changed == 'true'
         env:
           GITHUB_TOKEN: ${{ github.token }}
         run: |
           git config --global user.name "github-actions[bot]"
           git config --global user.email "github-actions[bot]@users.noreply.github.com"
           git add docs/optimizer_config.schema.json
           git commit -m "Update optimizer_config.schema.json"
-          git push
+          git push

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 from autointent import Context, Dataset
 from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
 from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
-from autointent.metrics import PREDICTION_METRICS_MULTILABEL
+from autointent.metrics import PREDICTION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.nodes.schemes import OptimizationConfig
 from autointent.utils import load_default_search_space, load_search_space
@@ -155,7 +155,7 @@ def fit(
             self._refit(context)

         predictions = self.predict(context.data_handler.test_utterances())
-        for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
+        for metric_name, metric in PREDICTION_METRICS.items():
             context.optimization_info.pipeline_metrics[metric_name] = metric(
                 context.data_handler.test_labels(),
                 predictions,

autointent/metrics/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -80,16 +80,14 @@
     scoring_neg_ranking_loss,
 )

-PREDICTION_METRICS_MULTICLASS: dict[str, DecisionMetricFn] = _funcs_to_dict(
+PREDICTION_METRICS: dict[str, DecisionMetricFn] = _funcs_to_dict(
     decision_accuracy,
     decision_f1,
     decision_precision,
     decision_recall,
     decision_roc_auc,
 )

-PREDICTION_METRICS_MULTILABEL = PREDICTION_METRICS_MULTICLASS
-
 REGEXP_METRICS = _funcs_to_dict(regexp_partial_accuracy, regexp_partial_precision)

 METRIC_FN = DecisionMetricFn | RegexpMetricFn | RetrievalMetricFn | ScoringMetricFn
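
With the two aliases collapsed into a single registry, decision metrics are now looked up by name from PREDICTION_METRICS regardless of whether the task is multiclass or multilabel. A minimal sketch of that lookup, with made-up toy labels and decisions (only the import and the dictionary keys come from the diff above):

    from autointent.metrics import PREDICTION_METRICS

    labels = [1, 0, 1, 0]     # toy ground-truth decisions
    decisions = [1, 0, 0, 0]  # toy predicted decisions

    # Pick a metric callable out of the shared registry by name.
    accuracy_fn = PREDICTION_METRICS["decision_accuracy"]
    print(accuracy_fn(labels, decisions))

This is the same pattern the pipeline and the decision modules below switch to: select callables from one dict by metric name instead of branching on multiclass vs. multilabel registries.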

autointent/modules/abc/_decision.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 from autointent import Context
 from autointent.context.optimization_info import DecisionArtifact
 from autointent.custom_types import ListOfGenericLabels
-from autointent.metrics import PREDICTION_METRICS_MULTICLASS
+from autointent.metrics import PREDICTION_METRICS
 from autointent.modules.abc import Module
 from autointent.schemas import Tag

@@ -53,7 +53,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:

         val_labels, val_scores = get_decision_evaluation_data(context, "validation")
         decisions = self.predict(val_scores)
-        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
         self._artifact = DecisionArtifact(labels=decisions)
         return self.score_metrics_ho((val_labels, decisions), chosen_metrics)

@@ -72,7 +72,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
             msg = "No folded scores are found."
             raise RuntimeError(msg)

-        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
         metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
         all_val_decisions = []
         for j in range(context.data_handler.n_folds):

autointent/modules/decision/_tunable.py

Lines changed: 28 additions & 12 deletions
@@ -1,6 +1,6 @@
 """Tunable predictor module."""

-from typing import Any
+from typing import Any, Literal

 import numpy as np
 import numpy.typing as npt
@@ -11,12 +11,14 @@
 from autointent.context import Context
 from autointent.custom_types import ListOfGenericLabels
 from autointent.exceptions import MismatchNumClassesError
-from autointent.metrics import decision_f1
+from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
 from autointent.modules.abc import DecisionModule
 from autointent.schemas import Tag

 from ._threshold import multiclass_predict, multilabel_predict

+MetricType = Literal["decision_accuracy", "decision_f1", "decision_roc_auc", "decision_precision", "decision_recall"]
+

 class TunableDecision(DecisionModule):
     """
@@ -40,7 +42,7 @@ class TunableDecision(DecisionModule):
         from autointent.modules import TunableDecision
         scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
         labels = [1, 0, 1]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
         predictor.fit(scores, labels)
         test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
         predictions = predictor.predict(test_scores)
@@ -55,15 +57,15 @@ class TunableDecision(DecisionModule):
     .. testcode::

         labels = [[1, 0], [0, 1], [1, 1]]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
         predictor.fit(scores, labels)
         test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
         predictions = predictor.predict(test_scores)
         print(predictions)

     .. testoutput::

-        [[1, 1], [1, 1]]
+        [[1, 0], [1, 0]]

     """

@@ -77,7 +79,8 @@ class TunableDecision(DecisionModule):

     def __init__(
         self,
-        n_trials: PositiveInt = 320,
+        target_metric: MetricType = "decision_accuracy",
+        n_optuna_trials: PositiveInt = 320,
         seed: int = 0,
         tags: list[Tag] | None = None,
     ) -> None:
@@ -88,19 +91,27 @@ def __init__(
         :param seed: Seed
         :param tags: Tags
         """
-        self.n_trials = n_trials
+        self.target_metric = target_metric
+        self.n_optuna_trials = n_optuna_trials
         self.seed = seed
         self.tags = tags

     @classmethod
-    def from_context(cls, context: Context, n_trials: PositiveInt = 320) -> "TunableDecision":
+    def from_context(
+        cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320
+    ) -> "TunableDecision":
         """
         Initialize from context.

         :param context: Context
         :param n_trials: Number of trials
         """
-        return cls(n_trials=n_trials, seed=context.seed, tags=context.data_handler.tags)
+        return cls(
+            target_metric=target_metric,
+            n_optuna_trials=n_optuna_trials,
+            seed=context.seed,
+            tags=context.data_handler.tags,
+        )

     def fit(
         self,
@@ -121,8 +132,10 @@ def fit(
         self.tags = tags
         self._validate_task(scores, labels)

+        metric_fn = PREDICTION_METRICS[self.target_metric]
+
         thresh_optimizer = ThreshOptimizer(
-            n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials
+            metric_fn, n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_optuna_trials
         )

         thresh_optimizer.fit(
@@ -150,14 +163,17 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
 class ThreshOptimizer:
     """Threshold optimizer."""

-    def __init__(self, n_classes: int, multilabel: bool, n_trials: int | None = None) -> None:
+    def __init__(
+        self, metric_fn: DecisionMetricFn, n_classes: int, multilabel: bool, n_trials: int | None = None
+    ) -> None:
         """
         Initialize threshold optimizer.

         :param n_classes: Number of classes
         :param multilabel: Is multilabel
         :param n_trials: Number of trials
         """
+        self.metric_fn = metric_fn
         self.n_classes = n_classes
         self.multilabel = multilabel
         self.n_trials = n_trials if n_trials is not None else n_classes * 10
@@ -173,7 +189,7 @@ def objective(self, trial: Trial) -> float:
             y_pred = multilabel_predict(self.probas, thresholds, self.tags)
         else:
             y_pred = multiclass_predict(self.probas, thresholds)
-        return decision_f1(self.labels, y_pred)
+        return self.metric_fn(self.labels, y_pred)

     def fit(
         self,
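
Putting the renamed arguments together, here is a usage sketch adapted from the doctest in the diff above; the scores and labels are the same toy values, and "decision_f1" is picked only to demonstrate the new target_metric knob (the default remains "decision_accuracy"):

    import numpy as np

    from autointent.modules import TunableDecision

    # Toy multiclass scores and labels, mirroring the doctest above.
    scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
    labels = [1, 0, 1]

    # n_trials was renamed to n_optuna_trials; target_metric selects which
    # entry of PREDICTION_METRICS the threshold search optimizes.
    predictor = TunableDecision(target_metric="decision_f1", n_optuna_trials=100, seed=42)
    predictor.fit(scores, labels)

    test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
    print(predictor.predict(test_scores))

Internally, fit() resolves target_metric through PREDICTION_METRICS and passes the callable to ThreshOptimizer, whose objective now scores candidate thresholds with that metric instead of the previously hard-coded decision_f1.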

autointent/nodes/_nodes_info/_decision.py

Lines changed: 2 additions & 4 deletions
@@ -4,7 +4,7 @@
 from typing import ClassVar

 from autointent.custom_types import NodeType
-from autointent.metrics import PREDICTION_METRICS_MULTICLASS, PREDICTION_METRICS_MULTILABEL, DecisionMetricFn
+from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
 from autointent.modules import PREDICTION_MODULES_MULTICLASS, PREDICTION_MODULES_MULTILABEL
 from autointent.modules.abc import DecisionModule

@@ -14,9 +14,7 @@
 class DecisionNodeInfo(NodeInfo):
     """Prediction node info."""

-    metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = (
-        PREDICTION_METRICS_MULTICLASS | PREDICTION_METRICS_MULTILABEL
-    )
+    metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = PREDICTION_METRICS

     modules_available: ClassVar[dict[str, type[DecisionModule]]] = (
         PREDICTION_MODULES_MULTICLASS | PREDICTION_MODULES_MULTILABEL

docs/optimizer_config.schema.json

Lines changed: 32 additions & 1 deletion
@@ -1425,6 +1425,37 @@
           "type": "string"
         },
         "n_trials": {
+          "anyOf": [
+            {
+              "exclusiveMinimum": 0,
+              "type": "integer"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "description": "Number of trials",
+          "title": "N Trials"
+        },
+        "target_metric": {
+          "default": [
+            "decision_accuracy"
+          ],
+          "items": {
+            "enum": [
+              "decision_accuracy",
+              "decision_f1",
+              "decision_roc_auc",
+              "decision_precision",
+              "decision_recall"
+            ],
+            "type": "string"
+          },
+          "title": "Target Metric",
+          "type": "array"
+        },
+        "n_optuna_trials": {
           "anyOf": [
             {
               "items": {
@@ -1440,7 +1471,7 @@
           "default": [
             320
           ],
-          "title": "N Trials"
+          "title": "N Optuna Trials"
         }
       },
       "required": [

tests/assets/configs/multiclass.yaml

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@
       thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
     - module_name: argmax
     - module_name: jinoos
+    - module_name: tunable

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@
     - module_name: threshold
       thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
     - module_name: adaptive
+    - module_name: tunable

tests/configs/test_decision.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@ def valid_decision_config():
             {"module_name": "threshold", "thresh": [[0.5, 0.6]]},
             {
                 "module_name": "tunable",
-                "n_trials": [100],
+                "n_optuna_trials": [100],
             },
             {"module_name": "adaptive", "search_space": [[0.5]]},
         ],
@@ -40,7 +40,7 @@ def test_invalid_decision_config_missing_field():
         {
             "node_type": "decision",
             # Missing "target_metric"
-            "search_space": [{"module_name": "tunable", "n_trials": [100]}],
+            "search_space": [{"module_name": "tunable", "n_optuna_trials": [100]}],
         }
     ]

@@ -61,7 +61,7 @@ def test_invalid_decision_config_wrong_type():
             },
             {
                 "module_name": "tunable",
-                "n_trials": ["not_an_int"],  # Should be an integer
+                "n_optuna_trials": ["not_an_int"],  # Should be an integer
             },
         ],
     }
