File tree Expand file tree Collapse file tree 3 files changed +11
-4
lines changed Expand file tree Collapse file tree 3 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -85,6 +85,7 @@ def evaluate(
8585
8686 scores = []
8787 for metric in metrics :
88+ print (f"evaluating with [{ metric .name } ]" )
8889 scores .append (metric .score (dataset ).select_columns (metric .name ))
8990
9091 # log the evaluation event
Original file line number Diff line number Diff line change @@ -129,6 +129,8 @@ def predict(
129129
130130 for _ , data in enumerate (dataloader ):
131131 inputs , labels = data
132+ inputs = {k : v .to (self ._target_device ) for k , v in inputs .items ()}
133+ labels = labels .to (self ._target_device )
132134 with torch .no_grad ():
133135 logits = self .model (** inputs , output_hidden_states = False ).logits
134136 loss = self .get_loss (logits , labels )
Original file line number Diff line number Diff line change 11import os
22
3- from datasets import Dataset
3+ from datasets import Dataset , load_dataset
44from torch .cuda import is_available
55
66from ragas import evaluate
99DEVICE = "cuda" if is_available () else "cpu"
1010
1111PATH_TO_DATSET_GIT_REPO = "../../../datasets/fiqa/"
12- assert os .path .isdir (PATH_TO_DATSET_GIT_REPO ), "Dataset not found"
13- ds = Dataset .from_json (os .path .join (PATH_TO_DATSET_GIT_REPO , "gen_ds.json" ))
14- assert isinstance (ds , Dataset )
12+ dataset_dir = os .environ .get ("DATASET_DIR" , PATH_TO_DATSET_GIT_REPO )
13+ if os .path .isdir (dataset_dir ):
14+ ds = Dataset .from_csv (os .path .join (dataset_dir , "baseline.csv" ))
15+ assert isinstance (ds , Dataset )
16+ else :
17+ # data
18+ ds = load_dataset ("explodinggradients/fiqa" , "ragas_eval" )["baseline" ]
1519
1620if __name__ == "__main__" :
1721 result = evaluate (
You can’t perform that action at this time.
0 commit comments