diff --git a/.github/workflows/check-schema.yaml b/.github/workflows/check-schema.yaml
index f47d1e330..9b87f134e 100644
--- a/.github/workflows/check-schema.yaml
+++ b/.github/workflows/check-schema.yaml
@@ -30,15 +30,18 @@ jobs:
           python -m scripts.generate_json_schema_config

       - name: Check for changes in JSON Schema
+        id: check_changes
         run: |
-          if ! git diff docs/optimizer_config.schema.json; then
-            echo "Error: docs/optimizer_config.schema.json has been modified after running the generator script."
+          if git diff --quiet docs/optimizer_config.schema.json; then
+            echo "No changes detected."
+            echo "changed=false" >> $GITHUB_ENV
           else
-            echo "No changes detected in docs/optimizer_config.schema.json."
-            exit 0
+            echo "Changes detected."
+            echo "changed=true" >> $GITHUB_ENV
           fi

       - name: Commit and push changes
+        if: env.changed == 'true'
         env:
           GITHUB_TOKEN: ${{ github.token }}
         run: |
@@ -46,4 +49,4 @@ jobs:
           git config --global user.email "github-actions[bot]@users.noreply.github.com"
           git add docs/optimizer_config.schema.json
           git commit -m "Update optimizer_config.schema.json"
-          git push
\ No newline at end of file
+          git push
diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index 3085b90ba..6f44cc67d 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -11,7 +11,7 @@
 from autointent import Context, Dataset
 from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
 from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType, ValidationScheme
-from autointent.metrics import PREDICTION_METRICS_MULTILABEL
+from autointent.metrics import PREDICTION_METRICS
 from autointent.nodes import InferenceNode, NodeOptimizer
 from autointent.nodes.schemes import OptimizationConfig
 from autointent.utils import load_default_search_space, load_search_space
@@ -155,7 +155,7 @@ def fit(
             self._refit(context)

         predictions = self.predict(context.data_handler.test_utterances())
-        for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
+        for metric_name, metric in PREDICTION_METRICS.items():
             context.optimization_info.pipeline_metrics[metric_name] = metric(
                 context.data_handler.test_labels(),
                 predictions,
diff --git a/autointent/metrics/__init__.py b/autointent/metrics/__init__.py
index ca356c929..4ba830e01 100644
--- a/autointent/metrics/__init__.py
+++ b/autointent/metrics/__init__.py
@@ -80,7 +80,7 @@
     scoring_neg_ranking_loss,
 )

-PREDICTION_METRICS_MULTICLASS: dict[str, DecisionMetricFn] = _funcs_to_dict(
+PREDICTION_METRICS: dict[str, DecisionMetricFn] = _funcs_to_dict(
     decision_accuracy,
     decision_f1,
     decision_precision,
@@ -88,8 +88,6 @@
     decision_roc_auc,
 )

-PREDICTION_METRICS_MULTILABEL = PREDICTION_METRICS_MULTICLASS
-
 REGEXP_METRICS = _funcs_to_dict(regexp_partial_accuracy, regexp_partial_precision)

 METRIC_FN = DecisionMetricFn | RegexpMetricFn | RetrievalMetricFn | ScoringMetricFn
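A minimal usage sketch of the consolidated registry (toy labels for illustration only; it assumes the dict keys mirror the decision metric function names, as the enum further down suggests):

    from autointent.metrics import PREDICTION_METRICS

    # Look up a decision metric by name and apply it to toy (true, predicted) labels.
    accuracy_fn = PREDICTION_METRICS["decision_accuracy"]
    print(accuracy_fn([0, 1, 1], [0, 1, 0]))  # 2 of the 3 toy predictions match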
diff --git a/autointent/modules/abc/_decision.py b/autointent/modules/abc/_decision.py
index dc128a11f..5d968b463 100644
--- a/autointent/modules/abc/_decision.py
+++ b/autointent/modules/abc/_decision.py
@@ -9,7 +9,7 @@
 from autointent import Context
 from autointent.context.optimization_info import DecisionArtifact
 from autointent.custom_types import ListOfGenericLabels
-from autointent.metrics import PREDICTION_METRICS_MULTICLASS
+from autointent.metrics import PREDICTION_METRICS
 from autointent.modules.abc import Module
 from autointent.schemas import Tag

@@ -53,7 +53,7 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
         val_labels, val_scores = get_decision_evaluation_data(context, "validation")
         decisions = self.predict(val_scores)
-        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
         self._artifact = DecisionArtifact(labels=decisions)
         return self.score_metrics_ho((val_labels, decisions), chosen_metrics)

@@ -72,7 +72,7 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
             msg = "No folded scores are found."
             raise RuntimeError(msg)

-        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
+        chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS.items() if name in metrics}
         metrics_values: dict[str, list[float]] = {name: [] for name in chosen_metrics}
         all_val_decisions = []
         for j in range(context.data_handler.n_folds):
diff --git a/autointent/modules/decision/_tunable.py b/autointent/modules/decision/_tunable.py
index f5bfdfb29..9a4b914a9 100644
--- a/autointent/modules/decision/_tunable.py
+++ b/autointent/modules/decision/_tunable.py
@@ -1,6 +1,6 @@
 """Tunable predictor module."""

-from typing import Any
+from typing import Any, Literal

 import numpy as np
 import numpy.typing as npt
@@ -11,12 +11,14 @@
 from autointent.context import Context
 from autointent.custom_types import ListOfGenericLabels
 from autointent.exceptions import MismatchNumClassesError
-from autointent.metrics import decision_f1
+from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
 from autointent.modules.abc import DecisionModule
 from autointent.schemas import Tag

 from ._threshold import multiclass_predict, multilabel_predict

+MetricType = Literal["decision_accuracy", "decision_f1", "decision_roc_auc", "decision_precision", "decision_recall"]
+

 class TunableDecision(DecisionModule):
     """
@@ -40,7 +42,7 @@ class TunableDecision(DecisionModule):
         from autointent.modules import TunableDecision
         scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
         labels = [1, 0, 1]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
         predictor.fit(scores, labels)
         test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
         predictions = predictor.predict(test_scores)
@@ -55,7 +57,7 @@ class TunableDecision(DecisionModule):
     .. testcode::

         labels = [[1, 0], [0, 1], [1, 1]]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
         predictor.fit(scores, labels)
         test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
         predictions = predictor.predict(test_scores)
@@ -63,7 +65,7 @@ class TunableDecision(DecisionModule):

     .. testoutput::

-        [[1, 1], [1, 1]]
+        [[1, 0], [1, 0]]

     """

@@ -77,7 +79,8 @@ class TunableDecision(DecisionModule):

     def __init__(
         self,
-        n_trials: PositiveInt = 320,
+        target_metric: MetricType = "decision_accuracy",
+        n_optuna_trials: PositiveInt = 320,
         seed: int = 0,
         tags: list[Tag] | None = None,
     ) -> None:
@@ -88,19 +91,27 @@ def __init__(
         :param seed: Seed
         :param tags: Tags
         """
-        self.n_trials = n_trials
+        self.target_metric = target_metric
+        self.n_optuna_trials = n_optuna_trials
         self.seed = seed
         self.tags = tags

     @classmethod
-    def from_context(cls, context: Context, n_trials: PositiveInt = 320) -> "TunableDecision":
+    def from_context(
+        cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320
+    ) -> "TunableDecision":
         """
         Initialize from context.

         :param context: Context
         :param n_trials: Number of trials
         """
-        return cls(n_trials=n_trials, seed=context.seed, tags=context.data_handler.tags)
+        return cls(
+            target_metric=target_metric,
+            n_optuna_trials=n_optuna_trials,
+            seed=context.seed,
+            tags=context.data_handler.tags,
+        )

     def fit(
         self,
@@ -121,8 +132,10 @@ def fit(
         self.tags = tags
         self._validate_task(scores, labels)

+        metric_fn = PREDICTION_METRICS[self.target_metric]
+
         thresh_optimizer = ThreshOptimizer(
-            n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials
+            metric_fn, n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_optuna_trials
         )

         thresh_optimizer.fit(
@@ -150,7 +163,9 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:

 class ThreshOptimizer:
     """Threshold optimizer."""

-    def __init__(self, n_classes: int, multilabel: bool, n_trials: int | None = None) -> None:
+    def __init__(
+        self, metric_fn: DecisionMetricFn, n_classes: int, multilabel: bool, n_trials: int | None = None
+    ) -> None:
         """
         Initialize threshold optimizer.
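A short sketch of the renamed constructor arguments in use, following the docstring examples above (toy scores; the tuned output is not asserted, and the metric choice is purely illustrative):

    import numpy as np
    from autointent.modules import TunableDecision

    # Tune thresholds against recall instead of the default decision_accuracy.
    predictor = TunableDecision(target_metric="decision_recall", n_optuna_trials=50, seed=0)
    predictor.fit(np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]]), [1, 0, 1])
    print(predictor.predict(np.array([[0.3, 0.7], [0.5, 0.5]])))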
@@ -158,6 +173,7 @@ def __init__(self, n_classes: int, multilabel: bool, n_trials: int | None = None
         :param multilabel: Is multilabel
         :param n_trials: Number of trials
         """
+        self.metric_fn = metric_fn
         self.n_classes = n_classes
         self.multilabel = multilabel
         self.n_trials = n_trials if n_trials is not None else n_classes * 10
@@ -173,7 +189,7 @@ def objective(self, trial: Trial) -> float:
             y_pred = multilabel_predict(self.probas, thresholds, self.tags)
         else:
             y_pred = multiclass_predict(self.probas, thresholds)
-        return decision_f1(self.labels, y_pred)
+        return self.metric_fn(self.labels, y_pred)

     def fit(
         self,
diff --git a/autointent/nodes/_nodes_info/_decision.py b/autointent/nodes/_nodes_info/_decision.py
index 3933abfaa..8843fc551 100644
--- a/autointent/nodes/_nodes_info/_decision.py
+++ b/autointent/nodes/_nodes_info/_decision.py
@@ -4,7 +4,7 @@
 from typing import ClassVar

 from autointent.custom_types import NodeType
-from autointent.metrics import PREDICTION_METRICS_MULTICLASS, PREDICTION_METRICS_MULTILABEL, DecisionMetricFn
+from autointent.metrics import PREDICTION_METRICS, DecisionMetricFn
 from autointent.modules import PREDICTION_MODULES_MULTICLASS, PREDICTION_MODULES_MULTILABEL
 from autointent.modules.abc import DecisionModule

@@ -14,9 +14,7 @@ class DecisionNodeInfo(NodeInfo):

     """Prediction node info."""

-    metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = (
-        PREDICTION_METRICS_MULTICLASS | PREDICTION_METRICS_MULTILABEL
-    )
+    metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = PREDICTION_METRICS

     modules_available: ClassVar[dict[str, type[DecisionModule]]] = (
         PREDICTION_MODULES_MULTICLASS | PREDICTION_MODULES_MULTILABEL
diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json
index f51fa93a3..603b16424 100644
--- a/docs/optimizer_config.schema.json
+++ b/docs/optimizer_config.schema.json
@@ -1425,6 +1425,37 @@
         "type": "string"
       },
       "n_trials": {
+        "anyOf": [
+          {
+            "exclusiveMinimum": 0,
+            "type": "integer"
+          },
+          {
+            "type": "null"
+          }
+        ],
+        "default": null,
+        "description": "Number of trials",
+        "title": "N Trials"
+      },
+      "target_metric": {
+        "default": [
+          "decision_accuracy"
+        ],
+        "items": {
+          "enum": [
+            "decision_accuracy",
+            "decision_f1",
+            "decision_roc_auc",
+            "decision_precision",
+            "decision_recall"
+          ],
+          "type": "string"
+        },
+        "title": "Target Metric",
+        "type": "array"
+      },
+      "n_optuna_trials": {
         "anyOf": [
           {
             "items": {
@@ -1440,7 +1471,7 @@
         "default": [
           320
         ],
-        "title": "N Trials"
+        "title": "N Optuna Trials"
       }
     },
     "required": [
diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml
index 3fbf8948c..e33e8e559 100644
--- a/tests/assets/configs/multiclass.yaml
+++ b/tests/assets/configs/multiclass.yaml
@@ -38,3 +38,4 @@
       thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
     - module_name: argmax
     - module_name: jinoos
+    - module_name: tunable
\ No newline at end of file
diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml
index 91742358a..159501a53 100644
--- a/tests/assets/configs/multilabel.yaml
+++ b/tests/assets/configs/multilabel.yaml
@@ -33,3 +33,4 @@
     - module_name: threshold
       thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
     - module_name: adaptive
+    - module_name: tunable
"tunable", - "n_trials": [100], + "n_optuna_trials": [100], }, {"module_name": "adaptive", "search_space": [[0.5]]}, ], @@ -40,7 +40,7 @@ def test_invalid_decision_config_missing_field(): { "node_type": "decision", # Missing "target_metric" - "search_space": [{"module_name": "tunable", "n_trials": [100]}], + "search_space": [{"module_name": "tunable", "n_optuna_trials": [100]}], } ] @@ -61,7 +61,7 @@ def test_invalid_decision_config_wrong_type(): }, { "module_name": "tunable", - "n_trials": ["not_an_int"], # Should be an integer + "n_optuna_trials": ["not_an_int"], # Should be an integer }, ], }