Commit 9bd1402

feat: add output type to metrics (#1722)
Added output_type as an optional parameter to LLM-based metrics, so that the loss required for optimising a metric can be derived from it. This can also be used later to change the UI layout when we introduce ranking-based metrics.
1 parent eb5f745 commit 9bd1402

21 files changed: +112 -11 lines changed
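
The enum itself is defined in ragas.metrics.base and is not shown in this diff. A minimal sketch of what it plausibly looks like, inferred only from the three member names used in the changes below (BINARY, DISCRETE, CONTINUOUS); the string values and anything else are assumptions:

from enum import Enum


class MetricOutputType(Enum):
    # Assumed string values; only the member names appear in this commit.
    BINARY = "binary"          # e.g. AspectCritic pass/fail verdicts
    DISCRETE = "discrete"      # e.g. rubric scores on a fixed scale
    CONTINUOUS = "continuous"  # e.g. precision/recall-style scores in [0, 1]

Dataclass-style metrics declare it as a class attribute (output_type = MetricOutputType.CONTINUOUS), while constructor-based metrics such as AspectCritic and the rubric metrics accept an output_type argument and forward it to the base Metric.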

src/ragas/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,7 @@
 from ragas.metrics._topic_adherence import TopicAdherenceScore
 from ragas.metrics.base import (
     Metric,
+    MetricOutputType,
     MetricType,
     MetricWithEmbeddings,
     MetricWithLLM,
@@ -76,6 +77,7 @@
     "MetricWithLLM",
     "SingleTurnMetric",
     "MultiTurnMetric",
+    "MetricOutputType",
     # specific metrics
     "AnswerCorrectness",
     "answer_correctness",

src/ragas/metrics/_answer_correctness.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
     LongFormAnswerPrompt,
 )
 from ragas.metrics.base import (
+    MetricOutputType,
     MetricType,
     MetricWithEmbeddings,
     MetricWithLLM,
@@ -163,6 +164,7 @@ class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
             MetricType.SINGLE_TURN: {"user_input", "response", "reference"}
         }
     )
+    output_type = MetricOutputType.CONTINUOUS
     correctness_prompt: PydanticPrompt = field(default_factory=CorrectnessClassifier)
     long_form_answer_prompt: PydanticPrompt = field(
         default_factory=LongFormAnswerPrompt

src/ragas/metrics/_answer_relevance.py

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 from ragas.dataset_schema import SingleTurnSample
 from ragas.metrics.base import (
+    MetricOutputType,
     MetricType,
     MetricWithEmbeddings,
     MetricWithLLM,
@@ -87,6 +88,8 @@ class ResponseRelevancy(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
             }
         }
     )
+    output_type = MetricOutputType.CONTINUOUS
+
     question_generation: PydanticPrompt = ResponseRelevancePrompt()
     strictness: int = 3

src/ragas/metrics/_answer_similarity.py

Lines changed: 7 additions & 1 deletion
@@ -8,7 +8,12 @@
 
 from ragas.dataset_schema import SingleTurnSample
 from ragas.embeddings.base import HuggingfaceEmbeddings
-from ragas.metrics.base import MetricType, MetricWithEmbeddings, SingleTurnMetric
+from ragas.metrics.base import (
+    MetricOutputType,
+    MetricType,
+    MetricWithEmbeddings,
+    SingleTurnMetric,
+)
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks.base import Callbacks
@@ -41,6 +46,7 @@ class SemanticSimilarity(MetricWithEmbeddings, SingleTurnMetric):
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {MetricType.SINGLE_TURN: {"reference", "response"}}
     )
+    output_type = MetricOutputType.CONTINUOUS
     is_cross_encoder: bool = False
     threshold: t.Optional[float] = None

src/ragas/metrics/_aspect_critic.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
 from ragas.metrics.base import (
+    MetricOutputType,
     MetricType,
     MetricWithLLM,
     MultiTurnMetric,
@@ -94,6 +95,7 @@ def __init__(
         definition: str,
         llm: t.Optional[BaseRagasLLM] = None,
         required_columns: t.Optional[t.Dict[MetricType, t.Set[str]]] = None,
+        output_type: t.Optional[MetricOutputType] = MetricOutputType.BINARY,
         single_turn_prompt: t.Optional[PydanticPrompt] = None,
         multi_turn_prompt: t.Optional[PydanticPrompt] = None,
         strictness: int = 1,
@@ -116,6 +118,7 @@
             name=name,
             _required_columns=self._required_columns,
             llm=llm,
+            output_type=output_type,
         )
 
         self._definition = definition
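
Since AspectCritic now accepts output_type in its constructor (defaulting to MetricOutputType.BINARY) and forwards it to the base Metric, callers can presumably set it explicitly. A hedged usage sketch; the metric name and definition below are invented for illustration:

from ragas.metrics import AspectCritic, MetricOutputType

# Illustrative only: the name and definition are made up, and the explicit
# output_type simply restates AspectCritic's default of BINARY.
harmfulness = AspectCritic(
    name="harmfulness",
    definition="Does the response contain or encourage harmful content?",
    output_type=MetricOutputType.BINARY,
)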

src/ragas/metrics/_context_entities_recall.py

Lines changed: 7 additions & 1 deletion
@@ -8,7 +8,12 @@
 from pydantic import BaseModel
 
 from ragas.dataset_schema import SingleTurnSample
-from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric
+from ragas.metrics.base import (
+    MetricOutputType,
+    MetricType,
+    MetricWithLLM,
+    SingleTurnMetric,
+)
 from ragas.prompt import PydanticPrompt, StringIO
 
 if t.TYPE_CHECKING:
@@ -113,6 +118,7 @@ class ContextEntityRecall(MetricWithLLM, SingleTurnMetric):
             MetricType.SINGLE_TURN: {"reference", "retrieved_contexts"}
         }
     )
+    output_type = MetricOutputType.CONTINUOUS
     context_entity_recall_prompt: PydanticPrompt = field(
         default_factory=ExtractEntitiesPrompt
     )

src/ragas/metrics/_context_precision.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,13 @@
 
 from ragas.dataset_schema import SingleTurnSample
 from ragas.metrics._string import NonLLMStringSimilarity
-from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric, ensembler
+from ragas.metrics.base import (
+    MetricOutputType,
+    MetricType,
+    MetricWithLLM,
+    SingleTurnMetric,
+    ensembler,
+)
 from ragas.prompt import PydanticPrompt
 from ragas.run_config import RunConfig
 from ragas.utils import deprecated
@@ -98,6 +104,7 @@ class LLMContextPrecisionWithReference(MetricWithLLM, SingleTurnMetric):
             }
         }
     )
+    output_type = MetricOutputType.CONTINUOUS
     context_precision_prompt: PydanticPrompt = field(
         default_factory=ContextPrecisionPrompt
     )

src/ragas/metrics/_context_recall.py

Lines changed: 9 additions & 1 deletion
@@ -9,7 +9,13 @@
 
 from ragas.dataset_schema import SingleTurnSample
 from ragas.metrics._string import NonLLMStringSimilarity
-from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric, ensembler
+from ragas.metrics.base import (
+    MetricOutputType,
+    MetricType,
+    MetricWithLLM,
+    SingleTurnMetric,
+    ensembler,
+)
 from ragas.prompt import PydanticPrompt
 from ragas.run_config import RunConfig
 from ragas.utils import deprecated
@@ -102,6 +108,7 @@ class LLMContextRecall(MetricWithLLM, SingleTurnMetric):
             }
         }
     )
+    output_type: t.Optional[MetricOutputType] = MetricOutputType.CONTINUOUS
     context_recall_prompt: PydanticPrompt = field(
         default_factory=ContextRecallClassificationPrompt
     )
@@ -202,6 +209,7 @@ class NonLLMContextRecall(SingleTurnMetric):
             }
         }
     )
+    output_type: MetricOutputType = MetricOutputType.CONTINUOUS
     distance_measure: SingleTurnMetric = field(
         default_factory=lambda: NonLLMStringSimilarity()
     )

src/ragas/metrics/_domain_specific_rubrics.py

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@
 
 from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
 from ragas.metrics.base import (
+    MetricOutputType,
     MetricType,
     MetricWithLLM,
     MultiTurnMetric,
@@ -88,6 +89,7 @@ def __init__(
         rubrics: t.Dict[str, str] = DEFAULT_REFERENCE_FREE_RUBRICS,
         llm: t.Optional[BaseRagasLLM] = None,
         required_columns: t.Optional[t.Dict[MetricType, t.Set[str]]] = None,
+        output_type: t.Optional[MetricOutputType] = MetricOutputType.DISCRETE,
         single_turn_prompt: t.Optional[PydanticPrompt] = None,
         multi_turn_prompt: t.Optional[PydanticPrompt] = None,
         max_retries: int = 1,
@@ -109,7 +111,12 @@
                 "reference:optional",
             },
         }
-        super().__init__(name=name, llm=llm, _required_columns=self._required_columns)
+        super().__init__(
+            name=name,
+            llm=llm,
+            _required_columns=self._required_columns,
+            output_type=output_type,
+        )
 
     def __repr__(self) -> str:
         return f"{self.name}(required_columns={self.required_columns}, llm={self.llm}), rubrics={self.rubrics}"

src/ragas/metrics/_factual_correctness.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
     NLIStatementPrompt,
 )
 from ragas.metrics.base import (
+    MetricOutputType,
     MetricType,
     MetricWithLLM,
     SingleTurnMetric,
@@ -210,6 +211,7 @@ class FactualCorrectness(MetricWithLLM, SingleTurnMetric):
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {MetricType.SINGLE_TURN: {"response", "reference"}}
    )
+    output_type: t.Optional[MetricOutputType] = MetricOutputType.CONTINUOUS
     mode: t.Literal["precision", "recall", "f1"] = "f1"
     beta: float = 1.0
     atomicity: t.Literal["low", "high"] = "low"
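
The commit description frames output_type as the signal an optimiser needs to choose a loss for a metric. A hypothetical helper, not part of ragas, showing how downstream code might branch on it:

import typing as t

from ragas.metrics import MetricOutputType


def loss_name_for(output_type: t.Optional[MetricOutputType]) -> str:
    # Hypothetical mapping; ragas does not ship this function or these loss names.
    if output_type == MetricOutputType.BINARY:
        return "binary_cross_entropy"
    if output_type == MetricOutputType.DISCRETE:
        return "cross_entropy"
    if output_type == MetricOutputType.CONTINUOUS:
        return "mean_squared_error"
    return "unspecified"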
