Skip to content

Commit 0ef19e6

Browse files
shahules786jjmachanruanhao
authored
feat: Instance specifc rubrics metrics (#1304)
Co-authored-by: Jithin James <[email protected]> Co-authored-by: Hao Ruan <[email protected]>
1 parent e9fb710 commit 0ef19e6

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

src/ragas/dataset_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class SingleTurnSample(BaseEvalSample):
3030
class MultiTurnSample(BaseEvalSample):
3131
user_input: t.List[t.Union[HumanMessage, AIMessage, ToolMessage]]
3232
reference: t.Optional[str] = None
33+
rubrics: t.Optional[t.Dict[str, str]] = None
3334

3435
@validator("user_input")
3536
def validate_messages(cls, messages):
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
from dataclasses import dataclass, field
5+
6+
from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
7+
from ragas.experimental.llms.prompt import PydanticPrompt
8+
from ragas.metrics._domain_specific_rubrics import (
9+
MultiTurnWithoutReferenceInput,
10+
MultiTurnWithoutReferencePrompt,
11+
MultiTurnWithReferenceInput,
12+
SingleTurnWithoutReferenceInput,
13+
SingleTurnWithoutReferencePrompt,
14+
SingleTurnWithReferenceInput,
15+
SingleTurnWithReferencePrompt,
16+
)
17+
from ragas.metrics.base import (
18+
MetricType,
19+
MetricWithLLM,
20+
MultiTurnMetric,
21+
SingleTurnMetric,
22+
)
23+
24+
if t.TYPE_CHECKING:
25+
from langchain_core.callbacks import Callbacks
26+
27+
28+
@dataclass
29+
class InstanceRubricsWithReference(MetricWithLLM, SingleTurnMetric, MultiTurnMetric):
30+
name: str = "labelled_rubrics_score" # type: ignore
31+
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
32+
default_factory=lambda: {
33+
MetricType.SINGLE_TURN: {"user_input", "response", "reference", "rubrics"},
34+
MetricType.MULTI_TURN: {"user_input", "reference", "rubrics"},
35+
}
36+
)
37+
single_turn_prompt: PydanticPrompt = field(
38+
default_factory=lambda: SingleTurnWithReferencePrompt()
39+
)
40+
multi_turn_prompt: PydanticPrompt = field(
41+
default_factory=lambda: MultiTurnWithoutReferencePrompt()
42+
)
43+
44+
max_retries: int = 1
45+
46+
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
47+
assert self.llm is not None, "LLM is not set"
48+
49+
user_input, contexts, response, reference, rubrics = (
50+
row["user_input"],
51+
row.get("retrieved_contexts"),
52+
row["response"],
53+
row["reference"],
54+
row["rubrics"],
55+
)
56+
if contexts is not None:
57+
contexts = "\n".join(contexts)
58+
user_input = f"{user_input} answer using context: {contexts}"
59+
60+
prompt_input = SingleTurnWithReferenceInput(
61+
user_input=user_input,
62+
response=response,
63+
reference=reference,
64+
rubrics=rubrics,
65+
)
66+
67+
response = await self.single_turn_prompt.generate(
68+
data=prompt_input, llm=self.llm, callbacks=callbacks
69+
)
70+
return response.score
71+
72+
async def _single_turn_ascore(
73+
self, sample: SingleTurnSample, callbacks: Callbacks
74+
) -> float:
75+
row = sample.dict()
76+
return await self._ascore(row, callbacks)
77+
78+
async def _multi_turn_ascore(
79+
self, sample: MultiTurnSample, callbacks: Callbacks
80+
) -> float:
81+
assert self.llm is not None, "LLM is not set"
82+
assert sample.rubrics is not None, "Rubrics are not set"
83+
assert sample.reference is not None, "Reference is not set"
84+
85+
interaction = sample.pretty_repr()
86+
reference = sample.reference
87+
rubrics = sample.rubrics
88+
prompt_input = MultiTurnWithReferenceInput(
89+
user_input=interaction,
90+
reference=reference,
91+
rubrics=rubrics,
92+
)
93+
output = await self.multi_turn_prompt.generate(
94+
data=prompt_input,
95+
llm=self.llm,
96+
callbacks=callbacks,
97+
)
98+
return output.score
99+
100+
101+
@dataclass
102+
class InstanceRubricsScoreWithoutReference(
103+
MetricWithLLM, SingleTurnMetric, MultiTurnMetric
104+
):
105+
name: str = "reference_free_rubrics_score" # type: ignore
106+
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
107+
default_factory=lambda: {
108+
MetricType.SINGLE_TURN: {"user_input", "response", "rubrics"},
109+
MetricType.MULTI_TURN: {"user_input", "rubrics"},
110+
}
111+
)
112+
single_turn_prompt: PydanticPrompt = field(
113+
default_factory=lambda: SingleTurnWithoutReferencePrompt()
114+
)
115+
multi_turn_prompt: PydanticPrompt = field(
116+
default_factory=lambda: MultiTurnWithoutReferencePrompt()
117+
)
118+
max_retries: int = 1
119+
120+
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
121+
assert self.llm is not None, "LLM is not set"
122+
123+
user_input, contexts, response, rubrics = (
124+
row["user_input"],
125+
row.get("retrieved_contexts"),
126+
row["response"],
127+
row["rubrics"],
128+
)
129+
if contexts is not None:
130+
contexts = "\n".join(contexts)
131+
user_input = f"{user_input} answer using context: {contexts}"
132+
133+
prompt_input = SingleTurnWithoutReferenceInput(
134+
user_input=user_input,
135+
response=response,
136+
rubrics=rubrics,
137+
)
138+
139+
response = await self.single_turn_prompt.generate(
140+
data=prompt_input, llm=self.llm, callbacks=callbacks
141+
)
142+
return response.score
143+
144+
async def _single_turn_ascore(
145+
self, sample: SingleTurnSample, callbacks: Callbacks
146+
) -> float:
147+
row = sample.dict()
148+
return await self._ascore(row, callbacks)
149+
150+
async def _multi_turn_ascore(
151+
self, sample: MultiTurnSample, callbacks: Callbacks
152+
) -> float:
153+
assert self.llm is not None, "LLM is not set"
154+
assert sample.rubrics is not None, "Rubrics are not set"
155+
interaction = sample.pretty_repr()
156+
rubrics = sample.rubrics
157+
prompt_input = MultiTurnWithoutReferenceInput(
158+
user_input=interaction,
159+
rubrics=rubrics,
160+
)
161+
output = await self.multi_turn_prompt.generate(
162+
data=prompt_input,
163+
llm=self.llm,
164+
callbacks=callbacks,
165+
)
166+
return output.score

0 commit comments

Comments
 (0)