@@ -48,8 +48,8 @@ def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
4848 :param split: Target split
4949 :return: Computed metrics value for the test set or error code of metrics
5050 """
51- train_scores , train_labels = self .get_train_data (context )
52- self .fit (train_scores , train_labels , context . data_handler . tags )
51+ train_scores , train_labels , tags = self .get_train_data (context )
52+ self .fit (train_scores , train_labels , tags )
5353
5454 val_labels , val_scores = get_decision_evaluation_data (context , "validation" )
5555 decisions = self .predict (val_scores )
@@ -73,22 +73,22 @@ def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
7373 raise RuntimeError (msg )
7474
7575 chosen_metrics = {name : fn for name , fn in PREDICTION_METRICS_MULTICLASS .items () if name in metrics }
76- metrics_values = {name : [] for name in chosen_metrics }
76+ metrics_values : dict [ str , list [ float ]] = {name : [] for name in chosen_metrics }
7777 all_val_decisions = []
7878 for j in range (context .data_handler .n_folds ):
7979 val_labels = labels [j ]
8080 val_scores = scores [j ]
8181 train_folds = [i for i in range (context .data_handler .n_folds ) if i != j ]
8282 train_labels = [ut for i_fold in train_folds for ut in labels [i_fold ]]
8383 train_scores = [ut for i_fold in train_folds for ut in scores [i_fold ]]
84- self .fit (train_scores , train_labels , context .data_handler .tags )
84+ self .fit (train_scores , train_labels , context .data_handler .tags ) # type: ignore[arg-type]
8585 val_decisions = self .predict (val_scores )
8686 for name , fn in chosen_metrics .items ():
8787 metrics_values [name ].append (fn (val_labels , val_decisions ))
8888 all_val_decisions .append (val_decisions )
8989
9090 self ._artifact = DecisionArtifact (labels = [pred for pred_list in all_val_decisions for pred in pred_list ])
91- return {name : np .mean (values_list ) for name , values_list in metrics_values .items ()}
91+ return {name : float ( np .mean (values_list ) ) for name , values_list in metrics_values .items ()}
9292
9393 def get_assets (self ) -> DecisionArtifact :
9494 """Return useful assets that represent intermediate data into context."""
0 commit comments