|
4 | 4 | C - contexts: context used for generation |
5 | 5 | G - ground_truth: ground truth answer |
6 | 6 | """ |
| 7 | + |
7 | 8 | from __future__ import annotations |
8 | 9 |
|
9 | 10 | import asyncio |
|
25 | 26 | EvaluationMode = Enum("EvaluationMode", "qac qa qc gc ga qga qcg") |
26 | 27 |
|
27 | 28 |
|
| 29 | +def get_required_columns( |
| 30 | + eval_mod: EvaluationMode, ignore_columns: t.Optional[t.List[str]] = None |
| 31 | +) -> t.List[str]: |
| 32 | + if eval_mod == EvaluationMode.qac: |
| 33 | + keys = ["question", "answer", "contexts"] |
| 34 | + elif eval_mod == EvaluationMode.qa: |
| 35 | + keys = ["question", "answer"] |
| 36 | + elif eval_mod == EvaluationMode.qc: |
| 37 | + keys = ["question", "contexts"] |
| 38 | + elif eval_mod == EvaluationMode.gc: |
| 39 | + keys = ["contexts", "ground_truth"] |
| 40 | + elif eval_mod == EvaluationMode.ga: |
| 41 | + keys = ["answer", "ground_truth"] |
| 42 | + elif eval_mod == EvaluationMode.qga: |
| 43 | + keys = ["question", "contexts", "answer", "ground_truth"] |
| 44 | + elif eval_mod == EvaluationMode.qcg: |
| 45 | + keys = ["question", "contexts", "ground_truth"] |
| 46 | + ignore_columns = ignore_columns or [] |
| 47 | + |
| 48 | + return [k for k in keys if k not in ignore_columns] |
| 49 | + |
| 50 | + |
28 | 51 | @dataclass |
29 | 52 | class Metric(ABC): |
30 | 53 | @property |
31 | 54 | @abstractmethod |
32 | | - def name(self) -> str: |
33 | | - ... |
| 55 | + def name(self) -> str: ... |
34 | 56 |
|
35 | 57 | @property |
36 | 58 | @abstractmethod |
37 | | - def evaluation_mode(self) -> EvaluationMode: |
38 | | - ... |
| 59 | + def evaluation_mode(self) -> EvaluationMode: ... |
39 | 60 |
|
40 | 61 | @abstractmethod |
41 | 62 | def init(self, run_config: RunConfig): |
@@ -97,8 +118,9 @@ async def ascore( |
97 | 118 | return score |
98 | 119 |
|
99 | 120 | @abstractmethod |
100 | | - async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float: |
101 | | - ... |
| 121 | + async def _ascore( |
| 122 | + self, row: t.Dict, callbacks: Callbacks, is_async: bool |
| 123 | + ) -> float: ... |
102 | 124 |
|
103 | 125 |
|
104 | 126 | @dataclass |
|
0 commit comments