Skip to content

Commit 3cf07c4

Browse files
authored
fix: error in handling device for tensors (#61)
fixes #60
1 parent 440d641 commit 3cf07c4

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

src/ragas/evaluation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def evaluate(
8585

8686
scores = []
8787
for metric in metrics:
88+
print(f"evaluating with [{metric.name}]")
8889
scores.append(metric.score(dataset).select_columns(metric.name))
8990

9091
# log the evaluation event

src/ragas/metrics/answer_relevance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def predict(
129129

130130
for _, data in enumerate(dataloader):
131131
inputs, labels = data
132+
inputs = {k: v.to(self._target_device) for k, v in inputs.items()}
133+
labels = labels.to(self._target_device)
132134
with torch.no_grad():
133135
logits = self.model(**inputs, output_hidden_states=False).logits
134136
loss = self.get_loss(logits, labels)

tests/benchmarks/benchmark_eval.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22

3-
from datasets import Dataset
3+
from datasets import Dataset, load_dataset
44
from torch.cuda import is_available
55

66
from ragas import evaluate
@@ -9,9 +9,13 @@
99
DEVICE = "cuda" if is_available() else "cpu"
1010

1111
PATH_TO_DATSET_GIT_REPO = "../../../datasets/fiqa/"
12-
assert os.path.isdir(PATH_TO_DATSET_GIT_REPO), "Dataset not found"
13-
ds = Dataset.from_json(os.path.join(PATH_TO_DATSET_GIT_REPO, "gen_ds.json"))
14-
assert isinstance(ds, Dataset)
12+
dataset_dir = os.environ.get("DATASET_DIR", PATH_TO_DATSET_GIT_REPO)
13+
if os.path.isdir(dataset_dir):
14+
ds = Dataset.from_csv(os.path.join(dataset_dir, "baseline.csv"))
15+
assert isinstance(ds, Dataset)
16+
else:
17+
# data
18+
ds = load_dataset("explodinggradients/fiqa", "ragas_eval")["baseline"]
1519

1620
if __name__ == "__main__":
1721
result = evaluate(

0 commit comments

Comments
 (0)