Skip to content

Commit ef88dc2

Browse files
committed
remove multiclass/multilabel separation on modules dicts
1 parent 31e9812 commit ef88dc2

File tree

4 files changed

+11
-24
lines changed

4 files changed

+11
-24
lines changed

autointent/modules/__init__.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,25 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2323

2424
REGEX_MODULES: dict[str, type[BaseRegex]] = _create_modules_dict([Regex])
2525

26-
EMBEDDING_MODULES_MULTICLASS: dict[str, type[BaseEmbedding]] = _create_modules_dict(
26+
EMBEDDING_MODULES: dict[str, type[BaseEmbedding]] = _create_modules_dict(
2727
[RetrievalAimedEmbedding, LogregAimedEmbedding]
2828
)
2929

30-
EMBEDDING_MODULES_MULTILABEL: dict[str, type[BaseEmbedding]] = EMBEDDING_MODULES_MULTICLASS
31-
32-
SCORING_MODULES_MULTICLASS: dict[str, type[BaseScorer]] = _create_modules_dict(
30+
SCORING_MODULES: dict[str, type[BaseScorer]] = _create_modules_dict(
3331
[
3432
DNNCScorer,
3533
KNNScorer,
3634
LinearScorer,
3735
DescriptionScorer,
3836
RerankScorer,
3937
SklearnScorer,
40-
]
41-
)
42-
43-
SCORING_MODULES_MULTILABEL: dict[str, type[BaseScorer]] = _create_modules_dict(
44-
[
4538
MLKnnScorer,
46-
LinearScorer,
47-
DescriptionScorer,
48-
SklearnScorer,
49-
],
39+
]
5040
)
5141

52-
DECISION_MODULES_MULTICLASS: dict[str, type[BaseDecision]] = _create_modules_dict(
53-
[ArgmaxDecision, JinoosDecision, ThresholdDecision, TunableDecision],
42+
DECISION_MODULES: dict[str, type[BaseDecision]] = _create_modules_dict(
43+
[ArgmaxDecision, JinoosDecision, ThresholdDecision, TunableDecision, AdaptiveDecision],
5444
)
5545

56-
DECISION_MODULES_MULTILABEL: dict[str, type[BaseDecision]] = _create_modules_dict(
57-
[AdaptiveDecision, ThresholdDecision, TunableDecision],
58-
)
5946

6047
__all__ = [] # type: ignore[var-annotated]

autointent/nodes/info/_decision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from autointent.custom_types import NodeType
77
from autointent.metrics import DECISION_METRICS, DecisionMetricFn
8-
from autointent.modules import DECISION_MODULES_MULTICLASS, DECISION_MODULES_MULTILABEL
8+
from autointent.modules import DECISION_MODULES
99
from autointent.modules.abc import BaseDecision
1010

1111
from ._base import NodeInfo
@@ -17,7 +17,7 @@ class DecisionNodeInfo(NodeInfo):
1717
metrics_available: ClassVar[Mapping[str, DecisionMetricFn]] = DECISION_METRICS
1818

1919
modules_available: ClassVar[dict[str, type[BaseDecision]]] = (
20-
DECISION_MODULES_MULTICLASS | DECISION_MODULES_MULTILABEL
20+
DECISION_MODULES
2121
)
2222

2323
node_type = NodeType.decision

autointent/nodes/info/_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
RetrievalMetricFn,
1313
ScoringMetricFn,
1414
)
15-
from autointent.modules import EMBEDDING_MODULES_MULTICLASS, EMBEDDING_MODULES_MULTILABEL
15+
from autointent.modules import EMBEDDING_MODULES
1616
from autointent.modules.abc import BaseEmbedding
1717

1818
from ._base import NodeInfo
@@ -29,7 +29,7 @@ class EmbeddingNodeInfo(NodeInfo):
2929
)
3030

3131
modules_available: ClassVar[Mapping[str, type[BaseEmbedding]]] = (
32-
EMBEDDING_MODULES_MULTICLASS | EMBEDDING_MODULES_MULTILABEL
32+
EMBEDDING_MODULES
3333
)
3434

3535
node_type = NodeType.embedding

autointent/nodes/info/_scoring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from autointent.custom_types import NodeType
77
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL, ScoringMetricFn
8-
from autointent.modules import SCORING_MODULES_MULTICLASS, SCORING_MODULES_MULTILABEL
8+
from autointent.modules import SCORING_MODULES
99
from autointent.modules.abc import BaseScorer
1010

1111
from ._base import NodeInfo
@@ -17,7 +17,7 @@ class ScoringNodeInfo(NodeInfo):
1717
metrics_available: ClassVar[Mapping[str, ScoringMetricFn]] = SCORING_METRICS_MULTICLASS | SCORING_METRICS_MULTILABEL
1818

1919
modules_available: ClassVar[Mapping[str, type[BaseScorer]]] = (
20-
SCORING_MODULES_MULTICLASS | SCORING_MODULES_MULTILABEL
20+
SCORING_MODULES
2121
)
2222

2323
node_type = NodeType.scoring

0 commit comments

Comments
 (0)