11import optuna
22
33from typing import Any , Callable , Dict , Sequence , Union , cast
4+ from framework3 import F1
45from framework3 .container import Container
5- from framework3 .base import BasePlugin , XYData
6+ from framework3 .base import BaseMetric , BasePlugin , XYData
67
78from rich import print
89
@@ -69,6 +70,7 @@ def __init__(
6970 pipeline : BaseFilter | None = None ,
7071 study_name : str | None = None ,
7172 storage : str | None = None ,
73+ scorer : BaseMetric = F1 (),
7274 ):
7375 """
7476 Initialize the OptunaOptimizer.
@@ -90,6 +92,7 @@ def __init__(
9092 self .n_trials = n_trials
9193 self .load_if_exists = load_if_exists
9294 self .reset_study = reset_study
95+ self .scorer = scorer
9396
9497 def optimize (self , pipeline : BaseFilter ):
9598 """
@@ -227,17 +230,12 @@ def matcher(k, v):
227230
228231 match pipeline .fit (x , y ):
229232 case None :
230- return float (
231- next (
232- iter (
233- pipeline .evaluate (
234- x , y , pipeline .predict (x )
235- ).values ()
236- )
237- )
238- )
233+ metrics = pipeline .evaluate (x , y , pipeline .predict (x ))
234+ return float (metrics .get (self .scorer .__class__ .__name__ , 0.0 ))
239235 case float () as loss :
240236 return loss
237+ case dict () as losses :
238+ return float (losses .get (self .scorer .__class__ .__name__ , 0.0 ))
241239 case _:
242240 raise ValueError ("Unsupported type in pipeline.fit" )
243241
0 commit comments