13 changes: 8 additions & 5 deletions .github/workflows/check-schema.yaml
@@ -30,20 +30,23 @@ jobs:
python -m scripts.generate_json_schema_config

- name: Check for changes in JSON Schema
+ id: check_changes
run: |
- if ! git diff docs/optimizer_config.schema.json; then
- echo "Error: docs/optimizer_config.schema.json has been modified after running the generator script."
+ if git diff --quiet docs/optimizer_config.schema.json; then
+ echo "No changes detected."
+ echo "changed=false" >> $GITHUB_ENV
else
echo "No changes detected in docs/optimizer_config.schema.json."
exit 0
echo "Changes detected."
echo "changed=true" >> $GITHUB_ENV
fi

- name: Commit and push changes
+ if: env.changed == 'true'
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
git add docs/optimizer_config.schema.json
git commit -m "Update optimizer_config.schema.json"
- git push
+ git push
4 changes: 2 additions & 2 deletions autointent/_pipeline/_pipeline.py
@@ -11,7 +11,7 @@
from autointent import Context, Dataset
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
- from autointent.metrics import PREDICTION_METRICS_MULTILABEL
+ from autointent.metrics import PREDICTION_METRICS
from autointent.nodes import InferenceNode, NodeOptimizer
from autointent.nodes.schemes import OptimizationConfig
from autointent.utils import load_default_search_space, load_search_space
@@ -155,7 +155,7 @@ def fit(
self._refit(context)

predictions = self.predict(context.data_handler.test_utterances())
- for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
+ for metric_name, metric in PREDICTION_METRICS.items():
context.optimization_info.pipeline_metrics[metric_name] = metric(
context.data_handler.test_labels(),
predictions,
4 changes: 1 addition & 3 deletions autointent/metrics/__init__.py
@@ -80,16 +80,14 @@
scoring_neg_ranking_loss,
)

- PREDICTION_METRICS_MULTICLASS: dict[str, DecisionMetricFn] = _funcs_to_dict(
+ PREDICTION_METRICS: dict[str, DecisionMetricFn] = _funcs_to_dict(
decision_accuracy,
decision_f1,
decision_precision,
decision_recall,
decision_roc_auc,
)

- PREDICTION_METRICS_MULTILABEL = PREDICTION_METRICS_MULTICLASS

REGEXP_METRICS = _funcs_to_dict(regexp_partial_accuracy, regexp_partial_precision)

METRIC_FN = DecisionMetricFn | RegexpMetricFn | RetrievalMetricFn | ScoringMetricFn
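For illustration, a minimal sketch of how the consolidated registry can be used downstream; the toy labels below are assumptions and not part of this diff:

from autointent.metrics import PREDICTION_METRICS

# The registry maps metric names to decision-metric functions called as
# metric(true_labels, predicted_labels), as in Pipeline.fit and DecisionModule.
print(sorted(PREDICTION_METRICS))
# expected: ['decision_accuracy', 'decision_f1', 'decision_precision', 'decision_recall', 'decision_roc_auc']

accuracy = PREDICTION_METRICS["decision_accuracy"]
print(accuracy([1, 0, 1], [1, 1, 1]))  # toy single-label example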
6 changes: 3 additions & 3 deletions autointent/modules/abc/_decision.py
@@ -9,7 +9,7 @@
from autointent import Context
from autointent.context.optimization_info import DecisionArtifact
from autointent.custom_types import ListOfGenericLabels
- from autointent.metrics import PREDICTION_METRICS_MULTICLASS
+ from autointent.metrics import PREDICTION_METRICS
from autointent.modules.abc import Module
from autointent.schemas import Tag

@@ -53,7 +53,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:

val_labels, val_scores = get_decision_evaluation_data(context, "validation")
decisions = self.predict(val_scores)
- chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+ chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
self._artifact = DecisionArtifact(labels=decisions)
return self.score_metrics_ho((val_labels, decisions), chosen_metrics)

@@ -72,7 +72,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
msg = "No folded scores are found."
raise RuntimeError(msg)

- chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+ chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
all_val_decisions = []
for j in range(context.data_handler.n_folds):
40 changes: 28 additions & 12 deletions autointent/modules/decision/_tunable.py
@@ -1,6 +1,6 @@
"""Tunable predictor module."""

- from typing import Any
+ from typing import Any, Literal

import numpy as np
import numpy.typing as npt
@@ -11,12 +11,14 @@
from autointent.context import Context
from autointent.custom_types import ListOfGenericLabels
from autointent.exceptions import MismatchNumClassesError
- from autointent.metrics import decision_f1
+ from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
from autointent.modules.abc import DecisionModule
from autointent.schemas import Tag

from ._threshold import multiclass_predict, multilabel_predict

+ MetricType = Literal["decision_accuracy", "decision_f1", "decision_roc_auc", "decision_precision", "decision_recall"]


class TunableDecision(DecisionModule):
"""
@@ -40,7 +42,7 @@ class TunableDecision(DecisionModule):
from autointent.modules import TunableDecision
scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
labels = [1, 0, 1]
- predictor = TunableDecision(n_trials=100, seed=42)
+ predictor = TunableDecision(n_optuna_trials=100, seed=42)
predictor.fit(scores, labels)
test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
predictions = predictor.predict(test_scores)
@@ -55,15 +57,15 @@
.. testcode::

labels = [[1, 0], [0, 1], [1, 1]]
- predictor = TunableDecision(n_trials=100, seed=42)
+ predictor = TunableDecision(n_optuna_trials=100, seed=42)
predictor.fit(scores, labels)
test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
predictions = predictor.predict(test_scores)
print(predictions)

.. testoutput::

- [[1, 1], [1, 1]]
+ [[1, 0], [1, 0]]

"""

@@ -77,7 +79,8 @@

def __init__(
self,
- n_trials: PositiveInt = 320,
+ target_metric: MetricType = "decision_accuracy",
+ n_optuna_trials: PositiveInt = 320,
seed: int = 0,
tags: list[Tag] | None = None,
) -> None:
@@ -88,19 +91,27 @@ def __init__(
:param seed: Seed
:param tags: Tags
"""
- self.n_trials = n_trials
+ self.target_metric = target_metric
+ self.n_optuna_trials = n_optuna_trials
self.seed = seed
self.tags = tags

@classmethod
- def from_context(cls, context: Context, n_trials: PositiveInt = 320) -> "TunableDecision":
+ def from_context(
+ cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320
+ ) -> "TunableDecision":
"""
Initialize from context.

:param context: Context
:param n_trials: Number of trials
"""
- return cls(n_trials=n_trials, seed=context.seed, tags=context.data_handler.tags)
+ return cls(
+ target_metric=target_metric,
+ n_optuna_trials=n_optuna_trials,
+ seed=context.seed,
+ tags=context.data_handler.tags,
+ )

def fit(
self,
@@ -121,8 +132,10 @@ def fit(
self.tags = tags
self._validate_task(scores, labels)

+ metric_fn = PREDICTION_METRICS[self.target_metric]

thresh_optimizer = ThreshOptimizer(
- n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials
+ metric_fn, n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_optuna_trials
)

thresh_optimizer.fit(
@@ -150,14 +163,17 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
class ThreshOptimizer:
"""Threshold optimizer."""

- def __init__(self, n_classes: int, multilabel: bool, n_trials: int | None = None) -> None:
+ def __init__(
+ self, metric_fn: DecisionMetricFn, n_classes: int, multilabel: bool, n_trials: int | None = None
+ ) -> None:
"""
Initialize threshold optimizer.

:param n_classes: Number of classes
:param multilabel: Is multilabel
:param n_trials: Number of trials
"""
+ self.metric_fn = metric_fn
self.n_classes = n_classes
self.multilabel = multilabel
self.n_trials = n_trials if n_trials is not None else n_classes * 10
@@ -173,7 +189,7 @@ def objective(self, trial: Trial) -> float:
y_pred = multilabel_predict(self.probas, thresholds, self.tags)
else:
y_pred = multiclass_predict(self.probas, thresholds)
- return decision_f1(self.labels, y_pred)
+ return self.metric_fn(self.labels, y_pred)

def fit(
self,
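For illustration, a minimal usage sketch of the updated constructor, adapted from the docstring example above; the metric choice and trial count are illustrative values, not part of this diff:

import numpy as np
from autointent.modules import TunableDecision

# Thresholds are now tuned against the configured target_metric rather than
# the previously hard-coded decision_f1.
scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
labels = [1, 0, 1]
predictor = TunableDecision(target_metric="decision_f1", n_optuna_trials=50, seed=42)
predictor.fit(scores, labels)
print(predictor.predict(np.array([[0.3, 0.7], [0.5, 0.5]])))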
6 changes: 2 additions & 4 deletions autointent/nodes/_nodes_info/_decision.py
@@ -4,7 +4,7 @@
from typing import ClassVar

from autointent.custom_types import NodeType
- from autointent.metrics import PREDICTION_METRICS_MULTICLASS, PREDICTION_METRICS_MULTILABEL, DecisionMetricFn
+ from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
from autointent.modules import PREDICTION_MODULES_MULTICLASS, PREDICTION_MODULES_MULTILABEL
from autointent.modules.abc import DecisionModule

@@ -14,9 +14,7 @@
class DecisionNodeInfo(NodeInfo):
"""Prediction node info."""

- metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = (
- PREDICTION_METRICS_MULTICLASS | PREDICTION_METRICS_MULTILABEL
- )
+ metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = PREDICTION_METRICS

modules_available: ClassVar[dict[str, type[DecisionModule]]] = (
PREDICTION_MODULES_MULTICLASS | PREDICTION_MODULES_MULTILABEL
33 changes: 32 additions & 1 deletion docs/optimizer_config.schema.json
@@ -1425,6 +1425,37 @@
"type": "string"
},
"n_trials": {
"anyOf": [
{
"exclusiveMinimum": 0,
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Number of trials",
"title": "N Trials"
},
"target_metric": {
"default": [
"decision_accuracy"
],
"items": {
"enum": [
"decision_accuracy",
"decision_f1",
"decision_roc_auc",
"decision_precision",
"decision_recall"
],
"type": "string"
},
"title": "Target Metric",
"type": "array"
},
"n_optuna_trials": {
"anyOf": [
{
"items": {
@@ -1440,7 +1471,7 @@
"default": [
320
],
"title": "N Trials"
"title": "N Optuna Trials"
}
},
"required": [
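A small sanity-check sketch (assumption: run from the repository root after python -m scripts.generate_json_schema_config, as in the workflow above) confirming the regenerated schema exposes the renamed fields:

from pathlib import Path

# Illustrative check only; the CI job above performs the authoritative comparison.
schema_text = Path("docs/optimizer_config.schema.json").read_text()
for field in ('"n_optuna_trials"', '"target_metric"'):
    assert field in schema_text, f"missing {field}"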
1 change: 1 addition & 0 deletions tests/assets/configs/multiclass.yaml
@@ -38,3 +38,4 @@
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_name: argmax
- module_name: jinoos
+ - module_name: tunable
1 change: 1 addition & 0 deletions tests/assets/configs/multilabel.yaml
@@ -33,3 +33,4 @@
- module_name: threshold
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_name: adaptive
+ - module_name: tunable
6 changes: 3 additions & 3 deletions tests/configs/test_decision.py
@@ -17,7 +17,7 @@ def valid_decision_config():
{"module_name": "threshold", "thresh": [[0.5, 0.6]]},
{
"module_name": "tunable",
"n_trials": [100],
"n_optuna_trials": [100],
},
{"module_name": "adaptive", "search_space": [[0.5]]},
],
@@ -40,7 +40,7 @@ def test_invalid_decision_config_missing_field():
{
"node_type": "decision",
# Missing "target_metric"
"search_space": [{"module_name": "tunable", "n_trials": [100]}],
"search_space": [{"module_name": "tunable", "n_optuna_trials": [100]}],
}
]

@@ -61,7 +61,7 @@ def test_invalid_decision_config_wrong_type():
},
{
"module_name": "tunable",
"n_trials": ["not_an_int"], # Should be an integer
"n_optuna_trials": ["not_an_int"], # Should be an integer
},
],
}