|
| 1 | +from typing import TypeVar |
| 2 | + |
1 | 3 | from .base import Module |
2 | 4 | from .prediction import ( |
3 | 5 | ArgmaxPredictor, |
|
10 | 12 | from .retrieval import RetrievalModule, VectorDBModule |
11 | 13 | from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule |
12 | 14 |
|
13 | | -RETRIEVAL_MODULES_MULTICLASS: dict[str, type[Module]] = { |
14 | | - "vector_db": VectorDBModule, |
15 | | -} |
| 15 | +T = TypeVar("T", bound=Module) |
| 16 | + |
| 17 | + |
| 18 | +def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: |
| 19 | + return {module.name: module for module in modules} |
| 20 | + |
| 21 | + |
| 22 | +RETRIEVAL_MODULES_MULTICLASS: dict[str, type[Module]] = create_modules_dict([VectorDBModule]) |
16 | 23 |
|
17 | 24 | RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS |
18 | 25 |
|
19 | | -SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = { |
20 | | - "dnnc": DNNCScorer, |
21 | | - "knn": KNNScorer, |
22 | | - "linear": LinearScorer, |
23 | | - "description": DescriptionScorer, |
24 | | -} |
| 26 | +SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict( |
| 27 | + [DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer] |
| 28 | +) |
| 29 | + |
| 30 | +SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict( |
| 31 | + [MLKnnScorer, LinearScorer, DescriptionScorer] |
| 32 | +) |
25 | 33 |
|
26 | | -SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = { |
27 | | - "knn": KNNScorer, |
28 | | - "linear": LinearScorer, |
29 | | - "mlknn": MLKnnScorer, |
30 | | -} |
| 34 | +PREDICTION_MODULES_MULTICLASS: dict[str, type[Module]] = create_modules_dict( |
| 35 | + [ArgmaxPredictor, JinoosPredictor, ThresholdPredictor, TunablePredictor] |
| 36 | +) |
31 | 37 |
|
32 | | -PREDICTION_MODULES_MULTICLASS: dict[str, type[Module]] = { |
33 | | - "argmax": ArgmaxPredictor, |
34 | | - "jinoos": JinoosPredictor, |
35 | | - "threshold": ThresholdPredictor, |
36 | | - "tunable": TunablePredictor, |
37 | | -} |
| 38 | +PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = create_modules_dict([ThresholdPredictor, TunablePredictor]) |
38 | 39 |
|
39 | | -PREDICTION_MODULES_MULTILABEL: dict[str, type[Module]] = { |
40 | | - "threshold": ThresholdPredictor, |
41 | | - "tunable": TunablePredictor, |
42 | | -} |
43 | 40 | __all__ = [ |
44 | 41 | "Module", |
45 | 42 | "ArgmaxPredictor", |
|
0 commit comments