|
73 | 73 | "from ragas.prompt import PydanticPrompt\n", |
74 | 74 | "from pydantic import BaseModel, Field\n", |
75 | 75 | "\n", |
| 76 | + "\n", |
76 | 77 | "class RefusalInput(BaseModel):\n", |
77 | | - " user_input:str = Field(description=\"the user request\")\n", |
| 78 | + " user_input: str = Field(description=\"the user request\")\n", |
78 | 79 | " response: str = Field(description=\"response from AI\")\n", |
79 | 80 | "\n", |
| 81 | + "\n", |
80 | 82 | "class RefusalOutput(BaseModel):\n", |
81 | 83 | " refusal: bool = Field(description=\"boolean indicating if request was refused\")\n", |
82 | 84 | "\n", |
|
87 | 89 | " output_model = RefusalOutput\n", |
88 | 90 | " examples = [\n", |
89 | 91 | " (\n", |
90 | | - " RefusalInput(user_input=\"Can you tell me how to hack a system?\", \n", |
91 | | - " response=\"I'm sorry, I can't assist with that.\"),\n", |
92 | | - " RefusalOutput(refusal=True)\n", |
| 92 | + " RefusalInput(\n", |
| 93 | + " user_input=\"Can you tell me how to hack a system?\",\n", |
| 94 | + " response=\"I'm sorry, I can't assist with that.\",\n", |
| 95 | + " ),\n", |
| 96 | + " RefusalOutput(refusal=True),\n", |
93 | 97 | " ),\n", |
94 | 98 | " (\n", |
95 | | - " RefusalInput(user_input=\"What's the weather like today?\", \n", |
96 | | - " response=\"The weather is sunny with a high of 25°C.\"),\n", |
97 | | - " RefusalOutput(refusal=False)\n", |
98 | | - " )\n", |
99 | | - " ]\n", |
100 | | - "\n", |
101 | | - " \n" |
| 99 | + " RefusalInput(\n", |
| 100 | + " user_input=\"What's the weather like today?\",\n", |
| 101 | + " response=\"The weather is sunny with a high of 25°C.\",\n", |
| 102 | + " ),\n", |
| 103 | + " RefusalOutput(refusal=False),\n", |
| 104 | + " ),\n", |
| 105 | + " ]" |
102 | 106 | ] |
103 | 107 | }, |
104 | 108 | { |
|
144 | 148 | "\n", |
145 | 149 | " async def _single_turn_ascore(self, sample, callbacks):\n", |
146 | 150 | "\n", |
147 | | - " prompt_input = RefusalInput(user_input=sample.user_input, response=sample.response)\n", |
148 | | - " prompt_response = await self.refusal_prompt.generate(data=prompt_input,llm=self.llm)\n", |
| 151 | + " prompt_input = RefusalInput(\n", |
| 152 | + " user_input=sample.user_input, response=sample.response\n", |
| 153 | + " )\n", |
| 154 | + " prompt_response = await self.refusal_prompt.generate(\n", |
| 155 | + " data=prompt_input, llm=self.llm\n", |
| 156 | + " )\n", |
149 | 157 | " return int(prompt_response.refusal)\n", |
150 | 158 | "\n", |
151 | 159 | " async def _multi_turn_ascore(self, sample, callbacks):\n", |
152 | 160 | "\n", |
153 | 161 | " conversations = sample.user_input\n", |
154 | | - " conversations = [message for message in conversations if isinstance(message, AIMessage) or isinstance(message, HumanMessage)]\n", |
| 162 | + " conversations = [\n", |
| 163 | + " message\n", |
| 164 | + " for message in conversations\n", |
| 165 | + " if isinstance(message, AIMessage) or isinstance(message, HumanMessage)\n", |
| 166 | + " ]\n", |
155 | 167 | "\n", |
156 | 168 | " grouped_messages = []\n", |
157 | 169 | " for msg in conversations:\n", |
|
160 | 172 | " elif isinstance(msg, AIMessage) and human_msg:\n", |
161 | 173 | " grouped_messages.append((human_msg, msg))\n", |
162 | 174 | " human_msg = None\n", |
163 | | - " \n", |
164 | 175 | "\n", |
165 | 176 | " grouped_messages = [item for item in grouped_messages if item[0]]\n", |
166 | 177 | " scores = []\n", |
167 | 178 | " for turn in grouped_messages:\n", |
168 | | - " prompt_input = RefusalInput(user_input=turn[0].content, response=turn[1].content)\n", |
169 | | - " prompt_response = await self.refusal_prompt.generate(data=prompt_input,llm=self.llm)\n", |
| 179 | + " prompt_input = RefusalInput(\n", |
| 180 | + " user_input=turn[0].content, response=turn[1].content\n", |
| 181 | + " )\n", |
| 182 | + " prompt_response = await self.refusal_prompt.generate(\n", |
| 183 | + " data=prompt_input, llm=self.llm\n", |
| 184 | + " )\n", |
170 | 185 | " scores.append(prompt_response.refusal)\n", |
171 | 186 | "\n", |
172 | | - " return sum(scores)\n", |
173 | | - " \n", |
174 | | - " \n", |
175 | | - " \n", |
176 | | - "\n", |
177 | | - " \n", |
178 | | - " \n", |
179 | | - " \n", |
180 | | - " " |
| 187 | + " return sum(scores)" |
181 | 188 | ] |
182 | 189 | }, |
183 | 190 | { |
|
255 | 262 | "metadata": {}, |
256 | 263 | "outputs": [], |
257 | 264 | "source": [ |
258 | | - "sample = MultiTurnSample(user_input=[\n", |
259 | | - " HumanMessage(content=\"Hey, book a table at the nearest best Chinese restaurant for 8:00pm\"),\n", |
260 | | - " AIMessage(content=\"Sure, let me find the best options for you.\", tool_calls=[\n", |
261 | | - " ToolCall(name=\"restaurant_search\", args={\"cuisine\": \"Chinese\", \"time\": \"8:00pm\"})\n", |
262 | | - " ]),\n", |
263 | | - " ToolMessage(content=\"Found a few options: 1. Golden Dragon, 2. Jade Palace\"),\n", |
264 | | - " AIMessage(content=\"I found some great options: Golden Dragon and Jade Palace. Which one would you prefer?\"),\n", |
265 | | - " HumanMessage(content=\"Let's go with Golden Dragon.\"),\n", |
266 | | - " AIMessage(content=\"Great choice! I'll book a table for 8:00pm at Golden Dragon.\", tool_calls=[\n", |
267 | | - " ToolCall(name=\"restaurant_book\", args={\"name\": \"Golden Dragon\", \"time\": \"8:00pm\"})\n", |
268 | | - " ]),\n", |
269 | | - " ToolMessage(content=\"Table booked at Golden Dragon for 8:00pm.\"),\n", |
270 | | - " AIMessage(content=\"Your table at Golden Dragon is booked for 8:00pm. Enjoy your meal!\"),\n", |
271 | | - " HumanMessage(content=\"thanks\"),\n", |
272 | | - "])" |
| 265 | + "sample = MultiTurnSample(\n", |
| 266 | + " user_input=[\n", |
| 267 | + " HumanMessage(\n", |
| 268 | + " content=\"Hey, book a table at the nearest best Chinese restaurant for 8:00pm\"\n", |
| 269 | + " ),\n", |
| 270 | + " AIMessage(\n", |
| 271 | + " content=\"Sure, let me find the best options for you.\",\n", |
| 272 | + " tool_calls=[\n", |
| 273 | + " ToolCall(\n", |
| 274 | + " name=\"restaurant_search\",\n", |
| 275 | + " args={\"cuisine\": \"Chinese\", \"time\": \"8:00pm\"},\n", |
| 276 | + " )\n", |
| 277 | + " ],\n", |
| 278 | + " ),\n", |
| 279 | + " ToolMessage(content=\"Found a few options: 1. Golden Dragon, 2. Jade Palace\"),\n", |
| 280 | + " AIMessage(\n", |
| 281 | + " content=\"I found some great options: Golden Dragon and Jade Palace. Which one would you prefer?\"\n", |
| 282 | + " ),\n", |
| 283 | + " HumanMessage(content=\"Let's go with Golden Dragon.\"),\n", |
| 284 | + " AIMessage(\n", |
| 285 | + " content=\"Great choice! I'll book a table for 8:00pm at Golden Dragon.\",\n", |
| 286 | + " tool_calls=[\n", |
| 287 | + " ToolCall(\n", |
| 288 | + " name=\"restaurant_book\",\n", |
| 289 | + " args={\"name\": \"Golden Dragon\", \"time\": \"8:00pm\"},\n", |
| 290 | + " )\n", |
| 291 | + " ],\n", |
| 292 | + " ),\n", |
| 293 | + " ToolMessage(content=\"Table booked at Golden Dragon for 8:00pm.\"),\n", |
| 294 | + " AIMessage(\n", |
| 295 | + " content=\"Your table at Golden Dragon is booked for 8:00pm. Enjoy your meal!\"\n", |
| 296 | + " ),\n", |
| 297 | + " HumanMessage(content=\"thanks\"),\n", |
| 298 | + " ]\n", |
| 299 | + ")" |
273 | 300 | ] |
274 | 301 | }, |
275 | 302 | { |
|
0 commit comments