Commit 78d3bea

feat: agent goal accuracy metric (#1303)
1 parent 0ef19e6 commit 78d3bea

1 file changed

Lines changed: 173 additions & 0 deletions
from __future__ import annotations

import typing as t
from dataclasses import dataclass, field

from pydantic import BaseModel, Field

from ragas.dataset_schema import MultiTurnSample
from ragas.experimental.llms.prompt import PydanticPrompt
from ragas.metrics.base import MetricType, MetricWithLLM, MultiTurnMetric

if t.TYPE_CHECKING:
    from langchain_core.callbacks.base import Callbacks


class WorkflowOutput(BaseModel):
    user_goal: str = Field(
        ..., description="The task or objective the user wants to achieve."
    )
    end_state: str = Field(
        ..., description="The final outcome or result of the workflow."
    )


class CompareOutcomeInput(BaseModel):
    desired_outcome: str = Field(
        ..., description="The desired outcome or result of the workflow."
    )
    arrived_outcome: str = Field(
        ..., description="The actual outcome or result of the workflow."
    )


class CompareOutcomeOutput(BaseModel):
    reason: str = Field(
        ..., description="The reasoning behind the verdict."
    )
    verdict: t.Literal["0", "1"] = Field(
        ..., description="1 if the arrived outcome matches the desired outcome, 0 otherwise."
    )


class WorkflowInput(BaseModel):
    workflow: str = Field(
        ..., description="The agentic workflow comprised of Human, AI and Tools."
    )


class InferGoalOutcomePrompt(PydanticPrompt[WorkflowInput, WorkflowOutput]):
    instruction = "Given an agentic workflow comprised of Human, AI and Tools, identify the user_goal (the task or objective the user wants to achieve) and the end_state (the final outcome or result of the workflow)."
    input_model = WorkflowInput
    output_model = WorkflowOutput
    examples = [
        (
            WorkflowInput(
                workflow="""
            Human: Hey, book a table at the nearest best Chinese restaurant for 8:00pm
            AI: Sure, let me find the best options for you.
            Tools:
                restaurant_search: {'cuisine': 'Chinese', 'time': '8:00pm'}
            ToolOutput: Found a few options: 1. Golden Dragon, 2. Jade Palace
            AI: I found some great options: Golden Dragon and Jade Palace. Which one would you prefer?
            Human: Let's go with Golden Dragon.
            AI: Great choice! I'll book a table for 8:00pm at Golden Dragon.
            Tools:
                restaurant_book: {'name': 'Golden Dragon', 'time': '8:00pm'}
            ToolOutput: Table booked at Golden Dragon for 8:00pm.
            AI: Your table at Golden Dragon is booked for 8:00pm. Enjoy your meal!
            Human: thanks
            """
            ),
            WorkflowOutput(
                user_goal="Book a table at the nearest best Chinese restaurant for 8:00pm.",
                end_state="A table is successfully booked at Golden Dragon (Chinese restaurant) for 8:00pm.",
            ),
        )
    ]


class CompareOutcomePrompt(PydanticPrompt[CompareOutcomeInput, CompareOutcomeOutput]):
    instruction = "Given the user goal, desired outcome and achieved outcome, compare them and identify if they are the same (1) or different (0)."
    input_model = CompareOutcomeInput
    output_model = CompareOutcomeOutput
    examples = [
        (
            CompareOutcomeInput(
                desired_outcome="A table is successfully booked at any Chinese restaurant for 8:00pm.",
                arrived_outcome="A table is successfully booked at Jade Palace (Chinese restaurant) for 8:00pm.",
            ),
            CompareOutcomeOutput(
                reason="The arrived outcome is the same as the desired outcome and aligns with the user goal.",
                verdict="1",
            ),
        )
    ]


@dataclass
class AgentGoalAccuracyWithReference(MetricWithLLM, MultiTurnMetric):
    name: str = "agent_goal_accuracy"  # type: ignore
    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
        default_factory=lambda: {
            MetricType.MULTI_TURN: {
                "user_input",
                "reference",
            }
        }
    )
    workflow_prompt: PydanticPrompt = field(
        default_factory=lambda: InferGoalOutcomePrompt()
    )
    compare_outcome_prompt: PydanticPrompt = field(
        default_factory=lambda: CompareOutcomePrompt()
    )
    max_retries: int = 1

    async def _multi_turn_ascore(
        self,
        sample: MultiTurnSample,
        callbacks: Callbacks,
    ) -> float:
        assert self.llm is not None, "LLM is not set"
        assert sample.reference is not None, "Reference is not set"

        # Infer the end state actually reached by the agentic workflow.
        prompt_input = WorkflowInput(workflow=sample.pretty_repr())
        response = await self.workflow_prompt.generate(
            data=prompt_input, llm=self.llm, callbacks=callbacks
        )
        # Compare the annotated reference outcome against the achieved end state.
        prompt_input = CompareOutcomeInput(
            desired_outcome=sample.reference, arrived_outcome=response.end_state
        )
        response = await self.compare_outcome_prompt.generate(
            data=prompt_input, llm=self.llm, callbacks=callbacks
        )
        return float(response.verdict)


@dataclass
class AgentGoalAccuracyWithoutReference(MetricWithLLM, MultiTurnMetric):
    name: str = "agent_goal_accuracy"  # type: ignore
    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
        default_factory=lambda: {
            MetricType.MULTI_TURN: {
                "user_input",
            }
        }
    )
    workflow_prompt: PydanticPrompt = field(
        default_factory=lambda: InferGoalOutcomePrompt()
    )
    compare_outcome_prompt: PydanticPrompt = field(
        default_factory=lambda: CompareOutcomePrompt()
    )
    max_retries: int = 1

    async def _multi_turn_ascore(
        self,
        sample: MultiTurnSample,
        callbacks: Callbacks,
    ) -> float:
        assert self.llm is not None, "LLM is not set"

        # Infer both the user goal and the end state from the conversation itself.
        prompt_input = WorkflowInput(workflow=sample.pretty_repr())
        response = await self.workflow_prompt.generate(
            data=prompt_input, llm=self.llm, callbacks=callbacks
        )
        # Compare the inferred goal against the achieved end state.
        prompt_input = CompareOutcomeInput(
            desired_outcome=response.user_goal, arrived_outcome=response.end_state
        )
        response = await self.compare_outcome_prompt.generate(
            data=prompt_input, llm=self.llm, callbacks=callbacks
        )
        return float(response.verdict)
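
For context, the new metric is meant to be scored like other multi-turn ragas metrics: wrap an evaluator LLM, build a MultiTurnSample from the conversation, and await the score. The snippet below is a minimal sketch of that flow, not part of this commit; it assumes the ragas MultiTurnSample/ragas.messages schema and the LangchainLLMWrapper helper behave as in the 0.2-era API, the evaluator model name is illustrative, and the metric classes are assumed to be importable from wherever this file lands in the package (the path is not shown in the diff).

    import asyncio

    from langchain_openai import ChatOpenAI
    from ragas.dataset_schema import MultiTurnSample
    from ragas.llms import LangchainLLMWrapper
    from ragas.messages import AIMessage, HumanMessage


    async def main() -> None:
        # Hypothetical evaluator LLM; any LangChain chat model wrapper should work here.
        evaluator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))

        sample = MultiTurnSample(
            user_input=[
                HumanMessage(content="Book a table at a Chinese restaurant for 8:00pm"),
                AIMessage(content="Your table at Golden Dragon is booked for 8:00pm."),
            ],
            # Annotated end state; only the with-reference variant requires it.
            reference="A table is booked at a Chinese restaurant for 8:00pm.",
        )

        metric = AgentGoalAccuracyWithReference(llm=evaluator_llm)
        score = await metric._multi_turn_ascore(sample, callbacks=None)
        print(score)  # 1.0 if the inferred end state matches the reference, else 0.0


    asyncio.run(main())

The without-reference variant is used the same way, except that no reference field is needed: it compares the end state against the user goal it infers from the conversation itself.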
