-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscoring.py
More file actions
31 lines (23 loc) · 859 Bytes
/
scoring.py
File metadata and controls
31 lines (23 loc) · 859 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from pydantic import BaseModel
from typing import List
from data_loading import Sample
class Scorer(BaseModel):
    """Base class for per-sample scorers.

    Subclasses override :meth:`run` to turn one ``Sample`` into a float score.
    """

    def run(self, sample: Sample) -> float:
        """Score a single sample.

        The base implementation is a placeholder: it always raises
        ``NotImplementedError``, so concrete scorers must override it.
        """
        raise NotImplementedError
class StateTransitionAccuracyScorer(Scorer):
    """Scorer registered under the name ``"state_transition_accuracy"``."""

    def run(self, sample: Sample) -> float:
        """Return 1.0 when ``sample.pred`` equals the string ``"1"``, else 0.0.

        The comparison is an exact ``==`` against ``"1"``. No stripping or
        type coercion is applied.
        NOTE(review): presumably ``"1"`` marks a correct/accepted transition
        prediction; confirm against how ``pred`` is produced.
        """
        is_positive = sample.pred == "1"
        return 1.0 if is_positive else 0.0
class StateCheckingAccuracyScorer(Scorer):
    """Scorer registered under the name ``"state_checking_accuracy"``."""

    def run(self, sample: Sample) -> float:
        """Return 1.0 on an exact match between prediction and expected status.

        Compares ``sample.pred`` with ``sample.outputs["current_status"]``
        using ``==``. Any mismatch yields 0.0.
        NOTE(review): assumes ``sample.outputs`` is a mapping that always
        contains ``"current_status"``; a missing key is not handled here.
        """
        expected_status = sample.outputs["current_status"]
        return 1.0 if sample.pred == expected_status else 0.0
def select_scorer(scorer_name: str) -> Scorer:
    """Instantiate the scorer registered under ``scorer_name``.

    Known names are ``"state_transition_accuracy"`` and
    ``"state_checking_accuracy"``.

    Raises:
        ValueError: if ``scorer_name`` matches none of the known names.
    """
    registry = (
        ("state_transition_accuracy", StateTransitionAccuracyScorer),
        ("state_checking_accuracy", StateCheckingAccuracyScorer),
    )
    # Linear scan with ``==`` (not a dict lookup) keeps the original
    # equality-based matching, including for unhashable inputs.
    scorer_cls = next(
        (cls for name, cls in registry if scorer_name == name),
        None,
    )
    if scorer_cls is None:
        raise ValueError(f"Scorer {scorer_name} not found")
    return scorer_cls()