diff --git a/deeppavlov/configs/faq/tfidf_logreg_autofaq.json b/deeppavlov/configs/faq/tfidf_logreg_autofaq.json index 9e2516fceb..18f5067379 100644 --- a/deeppavlov/configs/faq/tfidf_logreg_autofaq.json +++ b/deeppavlov/configs/faq/tfidf_logreg_autofaq.json @@ -73,7 +73,7 @@ }, { "in": "y_pred_proba", - "out": "y_pred_ids", + "out": ["y_pred_ids", "y_pred_proba_max"], "class_name": "proba2labels", "max_proba": true }, @@ -85,7 +85,7 @@ ], "out": [ "y_pred_answers", - "y_pred_proba" + "y_pred_proba_max" ] }, "train": { diff --git a/deeppavlov/configs/faq/tfidf_logreg_en_faq.json b/deeppavlov/configs/faq/tfidf_logreg_en_faq.json index 04ab1f34d4..ed7399a71b 100644 --- a/deeppavlov/configs/faq/tfidf_logreg_en_faq.json +++ b/deeppavlov/configs/faq/tfidf_logreg_en_faq.json @@ -72,7 +72,7 @@ }, { "in": "y_pred_proba", - "out": "y_pred_ids", + "out": ["y_pred_ids", "y_pred_proba_max"], "class_name": "proba2labels", "max_proba": true }, @@ -84,7 +84,7 @@ ], "out": [ "y_pred_answers", - "y_pred_proba" + "y_pred_proba_max" ] }, "train": { diff --git a/deeppavlov/models/classifiers/proba2labels.py b/deeppavlov/models/classifiers/proba2labels.py index 29ab96ebc5..b7dffd5cfe 100644 --- a/deeppavlov/models/classifiers/proba2labels.py +++ b/deeppavlov/models/classifiers/proba2labels.py @@ -13,7 +13,7 @@ # limitations under the License. from logging import getLogger -from typing import List, Union +from typing import List, Union, Tuple import numpy as np @@ -54,7 +54,7 @@ def __init__(self, self.top_n = top_n def __call__(self, data: Union[np.ndarray, List[List[float]], List[List[int]]], - *args, **kwargs) -> Union[List[List[int]], List[int]]: + *args, **kwargs) -> Tuple[Union[List[List[int]], List[int]], Union[List[List[int]], List[int]]]: """ Process probabilities to labels @@ -63,14 +63,18 @@ def __call__(self, data: Union[np.ndarray, List[List[float]], List[List[int]]], Returns: list of labels (only label classification) or list of lists of labels (multi-label classification) + # add comment here """ if self.confident_threshold: - return [list(np.where(np.array(d) > self.confident_threshold)[0]) + labels = [list(np.where(np.array(d) > self.confident_threshold)[0]) for d in data] + return labels, [[d[l] for l in label] for d, label in zip(data, labels)] elif self.max_proba: - return [np.argmax(d) for d in data] + labels = [np.argmax(d) for d in data] + return labels, [[d[l] for l in label] for d, label in zip(data, labels)] elif self.top_n: - return [np.argsort(d)[::-1][:self.top_n] for d in data] + labels = [np.argsort(d)[::-1][:self.top_n] for d in data] + return labels, [[d[l] for l in label] for d, label in zip(data, labels)] else: raise ConfigError("Proba2Labels requires one of three arguments: bool `max_proba` or " "float `confident_threshold` for multi-label classification or"