-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path_argmax.py
More file actions
90 lines (69 loc) · 2.32 KB
/
_argmax.py
File metadata and controls
90 lines (69 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Argmax decision module."""
import logging
from typing import Any
import numpy as np
import numpy.typing as npt
from autointent import Context
from autointent.custom_types import ListOfGenericLabels
from autointent.exceptions import MismatchNumClassesError
from autointent.modules.abc import DecisionModule
from autointent.schemas import Tag
logger = logging.getLogger(__name__)


class ArgmaxDecision(DecisionModule):
    """
    Argmax decision module.

    The ArgmaxDecision is a simple predictor that selects the class with the highest
    score (argmax) for single-label classification tasks.

    :ivar n_classes: Number of classes in the dataset.

    Examples
    --------
    .. testcode::

        from autointent.modules import ArgmaxDecision
        import numpy as np
        predictor = ArgmaxDecision()
        train_scores = np.array([[0.2, 0.8], [0.7, 0.3]])
        labels = [1, 0]  # Single-label targets
        predictor.fit(train_scores, labels)
        test_scores = np.array([[0.1, 0.9], [0.6, 0.4]])
        decisions = predictor.predict(test_scores)
        print(decisions)

    .. testoutput::

        [1, 0]

    """

    name = "argmax"
    supports_oos = False
    supports_multilabel = False
    supports_multiclass = True

    # Number of classes the module was fitted on; compared against the score
    # width in ``predict``. Only declared here -- presumably assigned by
    # ``DecisionModule._validate_task`` during ``fit`` (TODO confirm).
    _n_classes: int

    def __init__(self) -> None:
        """Create an argmax decision module; it takes no hyperparameters."""

    @classmethod
    def from_context(cls, context: Context) -> "ArgmaxDecision":
        """
        Initialize from context.

        The context is not consulted because this module has no hyperparameters.

        :param context: Context
        :return: A new ``ArgmaxDecision`` instance.
        """
        return cls()

    def fit(
        self,
        scores: npt.NDArray[Any],
        labels: ListOfGenericLabels,
        tags: list[Tag] | None = None,
    ) -> None:
        """
        Validate the training data; argmax itself learns no parameters.

        :param scores: Scores to fit
        :param labels: Labels to fit
        :param tags: Tags to fit (not used by this module)
        :raises WrongClassificationError: If the classification is wrong.
        """
        self._validate_task(scores, labels)

    def predict(self, scores: npt.NDArray[Any]) -> list[int]:
        """
        Predict the index of the highest-scoring class for each sample.

        :param scores: Score matrix with one row per sample and one column per class
        :return: Predicted class index for every row of ``scores``.
        :raises MismatchNumClassesError: If the number of score columns differs from
            the number of classes recorded during ``fit``.
        """
        # Accept array-likes (e.g. nested lists) as well as ndarrays; this is a
        # no-op for ndarray inputs.
        scores = np.asarray(scores)
        n_score_classes = scores.shape[1]
        if n_score_classes != self._n_classes:
            msg = (
                f"Expected scores for {self._n_classes} classes, "
                f"got {n_score_classes}."
            )
            raise MismatchNumClassesError(msg)
        # np.argmax returns the first maximal index, so ties go to the lowest class id.
        return np.argmax(scores, axis=1).tolist()  # type: ignore[no-any-return]