
Commit 135612d

feat: add validation step (#72)
Added a validation step that checks:
- that the columns required by the metrics passed are present in the dataset
- that those columns are of the correct datatype
1 parent e5fa2de commit 135612d

9 files changed: +143 additions, −26 deletions


Makefile

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@ run-benchmarks: ## Run benchmarks
 	@cd $(GIT_ROOT)/tests/benchmarks && python benchmark_eval.py
 test: ## Run tests
 	@echo "Running tests..."
-	@pytest tests/unit
+	@pytest tests/unit $(shell if [ -n "$(k)" ]; then echo "-k $(k)"; fi)
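The test target now accepts an optional k variable that is forwarded to pytest's -k expression filter, e.g. make test k=validation to run only matching tests; when k is unset the shell conditional expands to nothing and the whole tests/unit suite runs as before.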

src/ragas/evaluation.py

Lines changed: 5 additions & 21 deletions
@@ -1,29 +1,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from enum import Enum
 
 import numpy as np
 from datasets import Dataset, concatenate_datasets
 
 from ragas._analytics import EvaluationEvent, track
 from ragas.metrics.base import Metric
 from ragas.metrics.critique import AspectCritique
-
-EvaluationMode = Enum("EvaluationMode", "generative retrieval grounded")
-
-
-def get_evaluation_mode(ds: Dataset):
-    """
-    validates the dataset and returns the evaluation type
-
-    possible evaluation types
-    1. (q,a,c)
-    2. (q,a)
-    3. (q,c)
-    4. (g,a)
-    """
-    ...
+from ragas.validation import validate_column_dtypes, validate_evaluation_modes
 
 
 def evaluate(
@@ -70,16 +55,15 @@ def evaluate(
     if dataset is None:
         raise ValueError("Provide dataset!")
 
-    # TODO: validate EvaluationMode here
-    # evaluation_mode = get_evaluation_mode(dataset)
-
-    # TODO: check if all the metrics are compatible with the evaluation mode
-
     if metrics is None:
         from ragas.metrics import answer_relevancy, context_relevancy, faithfulness
 
         metrics = [answer_relevancy, context_relevancy, faithfulness]
 
+    # validation
+    validate_evaluation_modes(dataset, metrics)
+    validate_column_dtypes(dataset)
+
     # run the evaluation on dataset with different metrics
     # initialize all the models in the metrics
    [m.init_model() for m in metrics]
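In practice, evaluate() now fails fast, before any metric model is initialized, when the dataset does not match what the selected metrics need. A minimal sketch (the dataset content is a made-up placeholder, and evaluate is assumed to be importable from the package root):

from datasets import Dataset

from ragas import evaluate  # assumed to be re-exported at the package root
from ragas.metrics import faithfulness

# placeholder dataset that is missing the "contexts" column
ds = Dataset.from_dict(
    {
        "question": ["What is the capital of France?"],
        "answer": ["Paris"],
    }
)

# faithfulness declares EvaluationMode.qac (question, answer, contexts), so
# validate_evaluation_modes() raises ValueError before m.init_model() is ever called.
evaluate(ds, metrics=[faithfulness])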

src/ragas/metrics/answer_relevance.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
 from transformers import AutoConfig, AutoTokenizer
 from transformers.models.auto.modeling_auto import MODEL_WITH_LM_HEAD_MAPPING_NAMES
 
-from ragas.metrics.base import Metric
+from ragas.metrics.base import EvaluationMode, Metric
 
 if t.TYPE_CHECKING:
     import numpy.typing as npt
@@ -142,6 +142,7 @@ def predict(
 @dataclass
 class AnswerRelevancy(Metric):
     name: str = "answer_relevancy"
+    evaluation_mode: EvaluationMode = EvaluationMode.qa
     batch_size: int = 32
     model_name: str = "t5-base"

src/ragas/metrics/base.py

Lines changed: 9 additions & 0 deletions
@@ -9,6 +9,7 @@
 import typing as t
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
+from enum import Enum
 from math import floor
 
 from datasets import Dataset
@@ -32,6 +33,9 @@ def make_batches(total_size: int, batch_size: int) -> list[range]:
     return batches
 
 
+EvaluationMode = Enum("EvaluationMode", "qac qa qc ga")
+
+
 @dataclass
 class Metric(ABC):
     batch_size: int
@@ -41,6 +45,11 @@ class Metric(ABC):
     def name(self) -> str:
         ...
 
+    @property
+    @abstractmethod
+    def evaluation_mode(self) -> EvaluationMode:
+        ...
+
     @abstractmethod
     def init_model():
         """

src/ragas/metrics/context_relevance.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,7 @@
 from sentence_transformers import CrossEncoder
 from tqdm import tqdm
 
-from ragas.metrics.base import MetricWithLLM
+from ragas.metrics.base import EvaluationMode, MetricWithLLM
 from ragas.metrics.llms import generate
 
 CONTEXT_RELEVANCE = HumanMessagePromptTemplate.from_template(
@@ -105,6 +105,7 @@ class ContextRelevancy(MetricWithLLM):
     """
 
     name: str = "context_relavency"
+    evaluation_mode: EvaluationMode = EvaluationMode.qc
     batch_size: int = 15
     strictness: int = 2
     agreement_metric: str = "bert_score"

src/ragas/metrics/critique.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from tqdm import tqdm
 
-from ragas.metrics.base import MetricWithLLM, _llm_factory
+from ragas.metrics.base import EvaluationMode, MetricWithLLM, _llm_factory
 from ragas.metrics.llms import generate
 
 CRITIQUE_PROMPT = HumanMessagePromptTemplate.from_template(
@@ -53,6 +53,7 @@ class AspectCritique(MetricWithLLM):
     """
 
     name: str = field(default="", repr=True)
+    evaluation_mode: EvaluationMode = EvaluationMode.qac
     definition: str = field(default="", repr=True)
     strictness: int = field(default=1, repr=False)
     batch_size: int = field(default=15, repr=False)

src/ragas/metrics/faithfulnes.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from tqdm import tqdm
 
-from ragas.metrics.base import MetricWithLLM
+from ragas.metrics.base import EvaluationMode, MetricWithLLM
 from ragas.metrics.llms import generate
 
 if t.TYPE_CHECKING:
@@ -65,6 +65,7 @@
 @dataclass
 class Faithfulness(MetricWithLLM):
     name: str = "faithfulness"
+    evaluation_mode: EvaluationMode = EvaluationMode.qac
     batch_size: int = 15
 
     def init_model(self: t.Self):
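With the three defaults above in place, each built-in metric instance advertises the columns it can be scored on. A small inspection sketch (the printed values follow from the defaults in this diff; "context_relavency" is the name string exactly as it is defined in the repo):

from ragas.metrics import answer_relevancy, context_relevancy, faithfulness

for m in (answer_relevancy, context_relevancy, faithfulness):
    print(m.name, m.evaluation_mode)
# answer_relevancy EvaluationMode.qa
# context_relavency EvaluationMode.qc
# faithfulness EvaluationMode.qac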

src/ragas/validation.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from datasets import Dataset, Sequence
+
+from ragas.metrics.base import EvaluationMode, Metric
+
+
+def validate_column_dtypes(ds: Dataset):
+    for column_names in ["question", "answer"]:
+        if column_names in ds.features:
+            if ds.features[column_names].dtype != "string":
+                raise ValueError(
+                    f'Dataset feature "{column_names}" should be of type string'
+                )
+
+    for column_names in ["contexts", "ground_truths"]:
+        if column_names in ds.features:
+            if not (
+                isinstance(ds.features[column_names], Sequence)
+                and ds.features[column_names].feature.dtype == "string"
+            ):
+                raise ValueError(
+                    f'Dataset feature "{column_names}" should be of type'
+                    " Sequence[string]"
+                )
+
+
+EVALMODE_TO_COLUMNS = {
+    EvaluationMode.qac: ["question", "answer", "contexts"],
+    EvaluationMode.qa: ["question", "answer"],
+    EvaluationMode.qc: ["question", "contexts"],
+    EvaluationMode.ga: ["ground_truths", "answer"],
+}
+
+
+def validate_evaluation_modes(ds: Dataset, metrics: list[Metric]):
+    """
+    validates the dataset and returns the evaluation type
+
+    possible evaluation types
+    1. (q,a,c)
+    2. (q,a)
+    3. (q,c)
+    4. (g,a)
+    """
+
+    for m in metrics:
+        required_columns = set(EVALMODE_TO_COLUMNS[m.evaluation_mode])
+        available_columns = set(ds.features.keys())
+        if required_columns.symmetric_difference(available_columns):
+            raise ValueError(
+                f"The metric [{m.name}] that that is used requires the following "
+                f"additional columns {list(required_columns - available_columns)} "
+                "to be present in the dataset."
+            )
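A short sketch of exercising the two validators directly (the dataset contents are placeholders):

from datasets import Dataset

from ragas.metrics import context_relevancy
from ragas.validation import validate_column_dtypes, validate_evaluation_modes

ok = Dataset.from_dict(
    {
        "question": ["What does ragas validate?"],
        "contexts": [["the columns each metric needs"]],
    }
)
validate_column_dtypes(ok)                          # passes: string / Sequence[string]
validate_evaluation_modes(ok, [context_relevancy])  # passes: qc -> question + contexts

bad = Dataset.from_dict({"question": ["q"], "contexts": ["not wrapped in a list"]})
validate_column_dtypes(bad)  # raises ValueError: "contexts" should be Sequence[string]

Because validate_evaluation_modes compares the column sets with symmetric_difference, extra columns in the dataset fail validation as well as missing ones, although the error message only lists the missing columns.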

tests/unit/test_validation.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+from collections import namedtuple
+
+import pytest
+from datasets import Dataset
+
+from ragas.metrics import answer_relevancy, context_relevancy, faithfulness
+from ragas.validation import validate_column_dtypes, validate_evaluation_modes
+
+CaseToTest = namedtuple(
+    "TestCase", ["q", "a", "c", "g", "is_valid_columns", "metrics", "is_valid_metrics"]
+)
+
+TEST_CASES = [
+    CaseToTest("a", "b", ["c"], None, True, [faithfulness], True),
+    CaseToTest("a", "b", ["c"], ["g"], True, [faithfulness], False),
+    CaseToTest("a", None, ["c"], None, True, [context_relevancy], True),
+    CaseToTest("a", None, "c", None, False, [context_relevancy], True),
+    CaseToTest(
+        "a", None, [["c"]], None, False, [context_relevancy, answer_relevancy], False
+    ),
+    CaseToTest("a", None, ["c"], "g", False, [context_relevancy], False),
+    CaseToTest("a", None, ["c"], [["g"]], False, [context_relevancy], False),
+    CaseToTest(1, None, ["c"], ["g"], False, [context_relevancy], False),
+    CaseToTest(1, None, None, None, False, [context_relevancy], False),
+]
+
+
+@pytest.mark.parametrize("testcase", TEST_CASES)
+def test_validate_column_dtypes(testcase):
+    dataset_dict = {}
+    if testcase.q is not None:
+        dataset_dict["question"] = [testcase.q]
+    if testcase.a is not None:
+        dataset_dict["answer"] = [testcase.a]
+    if testcase.c is not None:
+        dataset_dict["contexts"] = [testcase.c]
+    if testcase.g is not None:
+        dataset_dict["ground_truths"] = [testcase.g]
+
+    test_dataset = Dataset.from_dict(dataset_dict)
+    if testcase.is_valid_columns:
+        validate_column_dtypes(test_dataset)
+    else:
+        with pytest.raises(ValueError):
+            validate_column_dtypes(test_dataset)
+
+
+@pytest.mark.parametrize("testcase", TEST_CASES)
+def test_validate_columns_and_metrics(testcase):
+    dataset_dict = {}
+    if testcase.q is not None:
+        dataset_dict["question"] = [testcase.q]
+    if testcase.a is not None:
+        dataset_dict["answer"] = [testcase.a]
+    if testcase.c is not None:
+        dataset_dict["contexts"] = [testcase.c]
+    if testcase.g is not None:
+        dataset_dict["ground_truths"] = [testcase.g]
+    test_dataset = Dataset.from_dict(dataset_dict)
+
+    if testcase.is_valid_metrics:
+        validate_evaluation_modes(test_dataset, testcase.metrics)
+    else:
+        with pytest.raises(ValueError):
+            validate_evaluation_modes(test_dataset, testcase.metrics)
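The same nine cases drive both tests. The second case (valid column types, but an extra ground_truths column alongside faithfulness) exercises the extra-column behaviour of validate_evaluation_modes noted above, and with the Makefile change in this commit the new tests can be run in isolation, e.g. make test k=validation.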
