@@ -63,27 +63,18 @@ def get_assets(self) -> DecisionArtifact:
6363 def clear_cache (self ) -> None :
6464 """Clear cache."""
6565
66- def _validate_inputs (self , scores : npt .NDArray [Any ], labels : ListOfGenericLabels ) -> tuple [int , bool , bool ]:
67- """
68- Sanity check if labels and scores are valid to be a training data for decision module.
69-
70- :param scores: training scores
71- :param labels: training labels
72- :return: number of classes, indicator if it's a multi-label task,
73- indicator if data contains oos samples
74- """
75- n_classes , multilabel , contains_oos_samples = super ()._get_task_specs (labels )
76-
77- if n_classes != scores .shape [1 ]:
66+ def _validate_task (self , scores : npt .NDArray [Any ], labels : ListOfGenericLabels ) -> None :
67+ self ._n_classes , self ._multilabel , self ._oos = self ._get_task_specs (labels )
68+ self ._validate_multilabel (self ._multilabel )
69+ self ._validate_oos (self ._oos , raise_error = False )
70+ if self ._n_classes != scores .shape [1 ]:
7871 msg = (
7972 "There is a mismatch between provided labels and scores. "
80- f"Labels contains { n_classes } classes, but scores contain "
73+ f"Labels contains { self . _n_classes } classes, but scores contain "
8174 f"probabilities for { scores .shape [1 ]} classes."
8275 )
8376 raise ValueError (msg )
8477
85- return n_classes , multilabel , contains_oos_samples
86-
8778
8879def get_decision_evaluation_data (
8980 context : Context ,
0 commit comments