
Commit b1666cd

Add retry attempts, demo notebook
1 parent c1f4fbc commit b1666cd

File tree

2 files changed: +256 −7 lines changed


docs/examples/response_is_on_topic.ipynb

Lines changed: 248 additions & 3 deletions
@@ -1,18 +1,263 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# On Topic Validation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This validator checks whether a text is related to a topic. Given a list of valid topics (one or many) and, optionally, a list of invalid topics, it validates that the text's main topic is one of the valid ones. If none of the valid topics is relevant, the topic 'Other' is considered the most relevant one and the validator fails.\n",
+    "\n",
+    "The validator supports three variants:\n",
+    "- An ensemble of a Zero-Shot classifier with an LLM fallback: if the classification score is below 0.5, an LLM is used to classify the main topic. This is the default behavior (`disable_classifier = False` and `disable_llm = False`).\n",
+    "- A Zero-Shot classifier alone (`disable_classifier = False` and `disable_llm = True`).\n",
+    "- An LLM alone (`disable_classifier = True` and `disable_llm = False`).\n",
+    "\n",
+    "To use the LLM, pass the name of any OpenAI ChatCompletion model, such as `gpt-3.5-turbo` or `gpt-4`, as the `llm_callable`, or pass a callable that handles the LLM calls itself. That callable can use any LLM you define. For simplicity, this demo uses OpenAI's gpt-3.5-turbo model.\n",
+    "\n",
+    "To authenticate with the OpenAI API, you have three options:\n",
+    "\n",
+    "- Set the OPENAI_API_KEY environment variable: os.environ[\"OPENAI_API_KEY\"] = \"<OpenAI_API_KEY>\"\n",
+    "- Set the key in code: openai.api_key = \"<OpenAI_API_KEY>\"\n",
+    "- Pass the api_key as a parameter to the parse function"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up a list of valid and invalid topics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "valid_topics = [\"bike\"]\n",
+    "invalid_topics = [\"phone\", \"tablet\", \"computer\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up the target topic"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text = \"\"\"Introducing the Galaxy Tab S7, a sleek and sophisticated device that seamlessly combines \\\n",
+    "cutting-edge technology with unparalleled design. With a stunning 5.1-inch Quad HD Super AMOLED display, \\\n",
+    "every detail comes to life in vibrant clarity. The Samsung Galaxy S7 boasts a powerful processor, \\\n",
+    "ensuring swift and responsive performance for all your tasks. \\\n",
+    "Capture your most cherished moments with the advanced camera system, which delivers stunning photos in any lighting conditions.\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up the device\n",
+    "\n",
+    "The `device` argument is an ordinal indicating CPU/GPU placement for the Zero-Shot classifier. Setting it to -1 (the default) uses the CPU; a positive integer runs the model on the associated CUDA device id."
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
-    "print('hi')"
+    "device = -1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Set up the model\n",
+    "\n",
+    "The `model` argument selects the Zero-Shot classification model used to classify the topic. See the list of available models [here](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=trending)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Test the validator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Version 1: Ensemble"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here, we use the text defined above as an example LLM output (`llm_output`). This sample text is about the topic 'tablet', which is explicitly listed in `invalid_topics`, so we expect the validator to fail."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Validation failed for field with errors: Most relevant topic is tablet.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import guardrails as gd\n",
+    "from guardrails.validators import OnTopic\n",
+    "\n",
+    "# Create the Guard with the OnTopic validator\n",
+    "guard = gd.Guard.from_string(\n",
+    "    validators=[\n",
+    "        OnTopic(\n",
+    "            valid_topics=valid_topics,\n",
+    "            invalid_topics=invalid_topics,\n",
+    "            device=device,\n",
+    "            llm_callable=\"gpt-3.5-turbo\",\n",
+    "            disable_classifier=False,\n",
+    "            disable_llm=False,\n",
+    "            on_fail=\"exception\",\n",
+    "        )\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "# Test with the sample text\n",
+    "output = guard.parse(\n",
+    "    llm_output=text,\n",
+    ")\n",
+    "\n",
+    "print(output.error)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Version 2: Zero-Shot\n",
+    "\n",
+    "Here, we disable the LLM entirely and rely solely on the Zero-Shot classifier's output. We again expect the validator to fail."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Validation failed for field with errors: Most relevant topic is tablet.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create the Guard with the OnTopic validator\n",
+    "guard = gd.Guard.from_string(\n",
+    "    validators=[\n",
+    "        OnTopic(\n",
+    "            valid_topics=valid_topics,\n",
+    "            invalid_topics=invalid_topics,\n",
+    "            device=device,\n",
+    "            disable_classifier=False,\n",
+    "            disable_llm=True,\n",
+    "            on_fail=\"exception\",\n",
+    "        )\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "# Test with the sample text\n",
+    "output = guard.parse(llm_output=text)\n",
+    "\n",
+    "print(output.error)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Version 3: LLM\n",
+    "\n",
+    "Finally, we run the validator with the LLM alone rather than as a fallback to the Zero-Shot classifier. This cell expects an OPENAI_API_KEY environment variable to be set. We again expect the validator to fail."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Validation failed for field with errors: Most relevant topic is tablet.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create the Guard with the OnTopic validator\n",
+    "guard = gd.Guard.from_string(\n",
+    "    validators=[\n",
+    "        OnTopic(\n",
+    "            valid_topics=valid_topics,\n",
+    "            invalid_topics=invalid_topics,\n",
+    "            llm_callable=\"gpt-3.5-turbo\",\n",
+    "            disable_classifier=True,\n",
+    "            disable_llm=False,\n",
+    "            on_fail=\"exception\",\n",
+    "        )\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "# Test with the sample text\n",
+    "output = guard.parse(\n",
+    "    llm_output=text\n",
+    ")\n",
+    "\n",
+    "print(output.error)"
    ]
   }
  ],
  "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
   "language_info": {
-   "name": "python"
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.17"
   }
  },
  "nbformat": 4,
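The ensemble variant the notebook describes (zero-shot classifier first, LLM fallback when the classification score falls below 0.5) can be sketched in isolation. This is a minimal illustration of the decision logic, not the validator's actual implementation; `pick_topic`, `classify`, and `ask_llm` are hypothetical names:

```python
from typing import Callable, List, Tuple

def pick_topic(
    text: str,
    topics: List[str],
    classify: Callable[[str, List[str]], Tuple[str, float]],
    ask_llm: Callable[[str, List[str]], str],
    threshold: float = 0.5,
) -> str:
    """Return the most relevant topic, falling back to the LLM
    when the zero-shot classifier's confidence is below threshold."""
    topic, score = classify(text, topics)
    if score < threshold:
        topic = ask_llm(text, topics)  # LLM fallback
    return topic

# Stubbed components, for illustration only:
confident = lambda text, topics: ("tablet", 0.9)  # classifier is sure
unsure = lambda text, topics: ("bike", 0.2)       # classifier is not
llm = lambda text, topics: "tablet"

print(pick_topic("...", ["bike", "tablet"], confident, llm))  # tablet (classifier wins)
print(pick_topic("...", ["bike", "tablet"], unsure, llm))     # tablet (LLM fallback)
```

Setting `disable_llm=True` corresponds to always trusting the classifier's answer, and `disable_classifier=True` to always calling the LLM.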

guardrails/validators/on_topic.py

Lines changed: 8 additions & 4 deletions
@@ -1,6 +1,6 @@
 import contextvars
 import json
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import openai
 from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -103,7 +103,11 @@ def __init__(
         self._model = model
         self._disable_classifier = disable_classifier
         self._disable_llm = disable_llm
-        self._model_threshold = model_threshold
+
+        if not model_threshold:
+            self._model_threshold = 0.5
+        else:
+            self._model_threshold = model_threshold

         self.set_callable(llm_callable)

@@ -142,7 +146,7 @@ def set_client(self):
             openai.api_version = api_base

     # todo: extract some of these similar methods into a base class w provenance
-    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(0))
+    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
     def call_llm(self, text: str, topics: List[str]) -> str:
         """Call the LLM with the given prompt.

@@ -216,7 +220,7 @@ def get_topic_zero_shot(
         score = result["scores"][0]  # type: ignore
         return topic, score  # type: ignore

-    def validate(self, value: str) -> ValidationResult:
+    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
         valid_topics = set(self._valid_topics)
         invalid_topics = set(self._invalid_topics)
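The retry change above replaces `stop_after_attempt(0)`, which allows the initial call but never a retry, with `stop_after_attempt(5)`, so `call_llm` now retries transient failures up to five attempts with a randomized exponential wait. A minimal stdlib sketch of that policy, without tenacity (the function and parameter names here are illustrative, not the guardrails API):

```python
import random
import time

def retry_with_backoff(fn, attempts=5, base=1.0, cap=60.0, sleep=time.sleep):
    """Call fn, retrying on exception up to `attempts` times, waiting a
    random exponential interval between tries (roughly what
    tenacity's wait_random_exponential + stop_after_attempt do)."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == attempts:
                raise  # out of attempts: re-raise the last error
            # random wait in [0, min(cap, base * 2^attempt)]
            sleep(random.uniform(0, min(cap, base * 2 ** attempt)))

# Demo: fails twice, then succeeds; sleep is stubbed out.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"

print(retry_with_backoff(flaky, sleep=lambda s: None))  # ok (after 3 calls)
```

With `attempts=0` the loop body never runs at all, which mirrors why the original `stop_after_attempt(0)` decorator provided no retry protection.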
