11"""Tunable predictor module."""
22
3- from typing import Any
3+ from typing import Any , Literal
44
55import numpy as np
66import numpy .typing as npt
1111from autointent .context import Context
1212from autointent .custom_types import ListOfGenericLabels
1313from autointent .exceptions import MismatchNumClassesError
14- from autointent .metrics import decision_f1
14+ from autointent .metrics import PREDICTION_METRICS , DecisionMetricFn
1515from autointent .modules .abc import DecisionModule
1616from autointent .schemas import Tag
1717
1818from ._threshold import multiclass_predict , multilabel_predict
1919
20+ MetricType = Literal ["decision_accuracy" , "decision_f1" , "decision_roc_auc" , "decision_precision" , "decision_recall" ]
21+
2022
2123class TunableDecision (DecisionModule ):
2224 """
@@ -40,7 +42,7 @@ class TunableDecision(DecisionModule):
        from autointent.modules import TunableDecision
        scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
        labels = [1, 0, 1]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
        predictor.fit(scores, labels)
        test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
        predictions = predictor.predict(test_scores)
@@ -55,15 +57,15 @@ class TunableDecision(DecisionModule):
    .. testcode::

        labels = [[1, 0], [0, 1], [1, 1]]
-        predictor = TunableDecision(n_trials=100, seed=42)
+        predictor = TunableDecision(n_optuna_trials=100, seed=42)
        predictor.fit(scores, labels)
        test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
        predictions = predictor.predict(test_scores)
        print(predictions)

    .. testoutput::

-        [[1, 1], [1, 1]]
+        [[1, 0], [1, 0]]

    """

@@ -77,7 +79,8 @@ class TunableDecision(DecisionModule):

    def __init__(
        self,
-        n_trials: PositiveInt = 320,
+        target_metric: MetricType = "decision_accuracy",
+        n_optuna_trials: PositiveInt = 320,
        seed: int = 0,
        tags: list[Tag] | None = None,
    ) -> None:
@@ -88,19 +91,27 @@ def __init__(
        :param seed: Seed
        :param tags: Tags
        """
-        self.n_trials = n_trials
+        self.target_metric = target_metric
+        self.n_optuna_trials = n_optuna_trials
        self.seed = seed
        self.tags = tags

    @classmethod
-    def from_context(cls, context: Context, n_trials: PositiveInt = 320) -> "TunableDecision":
+    def from_context(
+        cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320
+    ) -> "TunableDecision":
        """
        Initialize from context.

        :param context: Context
        :param n_trials: Number of trials
        """
-        return cls(n_trials=n_trials, seed=context.seed, tags=context.data_handler.tags)
+        return cls(
+            target_metric=target_metric,
+            n_optuna_trials=n_optuna_trials,
+            seed=context.seed,
+            tags=context.data_handler.tags,
+        )

    def fit(
        self,
@@ -121,8 +132,10 @@ def fit(
        self.tags = tags
        self._validate_task(scores, labels)

+        metric_fn = PREDICTION_METRICS[self.target_metric]
+
        thresh_optimizer = ThreshOptimizer(
-            n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_trials
+            metric_fn, n_classes=self._n_classes, multilabel=self._multilabel, n_trials=self.n_optuna_trials
        )

        thresh_optimizer.fit(
@@ -150,14 +163,17 @@ def predict(self, scores: npt.NDArray[Any]) -> ListOfGenericLabels:
class ThreshOptimizer:
    """Threshold optimizer."""

-    def __init__(self, n_classes: int, multilabel: bool, n_trials: int | None = None) -> None:
+    def __init__(
+        self, metric_fn: DecisionMetricFn, n_classes: int, multilabel: bool, n_trials: int | None = None
+    ) -> None:
        """
        Initialize threshold optimizer.

        :param n_classes: Number of classes
        :param multilabel: Is multilabel
        :param n_trials: Number of trials
        """
+        self.metric_fn = metric_fn
        self.n_classes = n_classes
        self.multilabel = multilabel
        self.n_trials = n_trials if n_trials is not None else n_classes * 10
@@ -173,7 +189,7 @@ def objective(self, trial: Trial) -> float:
            y_pred = multilabel_predict(self.probas, thresholds, self.tags)
        else:
            y_pred = multiclass_predict(self.probas, thresholds)
-        return decision_f1(self.labels, y_pred)
+        return self.metric_fn(self.labels, y_pred)

    def fit(
        self,
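
A minimal usage sketch of the module after this change (hypothetical score values, mirroring the docstring examples above; it assumes only the imports and the new target_metric / n_optuna_trials parameters shown in this diff):

    import numpy as np

    from autointent.modules import TunableDecision

    # any name listed in MetricType is accepted; "decision_f1" matches the previously hard-coded metric
    predictor = TunableDecision(target_metric="decision_f1", n_optuna_trials=100, seed=42)

    scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
    labels = [1, 0, 1]
    predictor.fit(scores, labels)

    predictions = predictor.predict(np.array([[0.3, 0.7], [0.5, 0.5]]))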