Skip to content

Commit 5bce4de

Browse files
committed
Merge branch 'main' into 0.3.0
2 parents 8022c6f + ff395e8 commit 5bce4de

File tree

19 files changed

+508
-25
lines changed

19 files changed

+508
-25
lines changed

.github/workflows/deploy_docs.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@ jobs:
3333
uses: actions/cache@v3
3434
with:
3535
path: ~/.cache/pypoetry
36-
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }
36+
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}
3737
- name: Install Poetry
3838
uses: snok/install-poetry@v1
3939
- name: Install dependencies
4040
run: poetry install --with docs
4141
- name: Build
42-
run: mkdocs build
42+
run: poetry run mkdocs build
4343
- name: Upload artifact
4444
uses: actions/upload-pages-artifact@v2
4545
with:
4646
# Upload build folder
4747
path: 'site'
4848
- name: Deploy to GitHub Pages
4949
id: deployment
50-
uses: actions/deploy-pages@v2
50+
uses: actions/deploy-pages@v2

.github/workflows/scripts/run_notebooks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cd docs/examples
99
# Function to process a notebook
1010
process_notebook() {
1111
notebook="$1"
12-
invalid_notebooks=("valid_chess_moves.ipynb" "translation_with_quality_check.ipynb" "llamaindex-output-parsing.ipynb")
12+
invalid_notebooks=("valid_chess_moves.ipynb" "translation_with_quality_check.ipynb" "llamaindex-output-parsing.ipynb" "competitors_check.ipynb")
1313
if [[ ! " ${invalid_notebooks[@]} " =~ " ${notebook} " ]]; then
1414
echo "Processing $notebook..."
1515
poetry run jupyter nbconvert --to notebook --execute "$notebook"

docs/api_reference/validators.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
- "!validate"
1212
- "!register_validator"
1313
- "!PydanticReAsk"
14-
- "!Filter"
1514
- "!Refrain"
1615
- "!ValidationResult"
1716
- "!PassResult"

docs/examples/competitors_check.ipynb

Lines changed: 240 additions & 0 deletions
Large diffs are not rendered by default.

guardrails/utils/openai_utils/v0.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def construct_nonchat_response(
8484
) -> LLMResponse:
8585
"""Construct an LLMResponse from an OpenAI response.
8686
87-
Splits execution based on whether the `stream` parameter
88-
is set in the kwargs.
87+
Splits execution based on whether the `stream` parameter is set
88+
in the kwargs.
8989
"""
9090
if stream:
9191
# If stream is defined and set to True,
@@ -152,8 +152,8 @@ def construct_chat_response(
152152
) -> LLMResponse:
153153
"""Construct an LLMResponse from an OpenAI response.
154154
155-
Splits execution based on whether the `stream` parameter
156-
is set in the kwargs.
155+
Splits execution based on whether the `stream` parameter is set
156+
in the kwargs.
157157
"""
158158
if stream:
159159
# If stream is defined and set to True,
@@ -296,8 +296,8 @@ async def construct_chat_response(
296296
) -> LLMResponse:
297297
"""Construct an LLMResponse from an OpenAI response.
298298
299-
Splits execution based on whether the `stream` parameter
300-
is set in the kwargs.
299+
Splits execution based on whether the `stream` parameter is set
300+
in the kwargs.
301301
"""
302302
if stream:
303303
# If stream is defined and set to True,

guardrails/utils/openai_utils/v1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def construct_nonchat_response(
7676
) -> LLMResponse:
7777
"""Construct an LLMResponse from an OpenAI response.
7878
79-
Splits execution based on whether the `stream` parameter
80-
is set in the kwargs.
79+
Splits execution based on whether the `stream` parameter is set
80+
in the kwargs.
8181
"""
8282
if stream:
8383
# If stream is defined and set to True,
@@ -140,8 +140,8 @@ def construct_chat_response(
140140
) -> LLMResponse:
141141
"""Construct an LLMResponse from an OpenAI response.
142142
143-
Splits execution based on whether the `stream` parameter
144-
is set in the kwargs.
143+
Splits execution based on whether the `stream` parameter is set
144+
in the kwargs.
145145
"""
146146
if stream:
147147
# If stream is defined and set to True,
@@ -298,8 +298,8 @@ async def construct_chat_response(
298298
) -> LLMResponse:
299299
"""Construct an LLMResponse from an OpenAI response.
300300
301-
Splits execution based on whether the `stream` parameter
302-
is set in the kwargs.
301+
Splits execution based on whether the `stream` parameter is set
302+
in the kwargs.
303303
"""
304304
if stream:
305305
# If stream is defined and set to True,

guardrails/validators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from guardrails.validators.bug_free_python import BugFreePython
1515
from guardrails.validators.bug_free_sql import BugFreeSQL
16+
from guardrails.validators.competitor_check import CompetitorCheck
1617
from guardrails.validators.detect_secrets import DetectSecrets, detect_secrets
1718
from guardrails.validators.endpoint_is_reachable import EndpointIsReachable
1819
from guardrails.validators.ends_with import EndsWith
@@ -75,6 +76,7 @@
7576
"PIIFilter",
7677
"SimilarToList",
7778
"DetectSecrets",
79+
"CompetitorCheck",
7880
# Validator helpers
7981
"detect_secrets",
8082
"AnalyzerEngine",
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import re
2+
from typing import Callable, Dict, List, Optional
3+
4+
from guardrails.logger import logger
5+
from guardrails.validator_base import (
6+
FailResult,
7+
PassResult,
8+
ValidationResult,
9+
Validator,
10+
register_validator,
11+
)
12+
13+
14+
# nltk is an optional dependency: fall back to None when it is not installed
# so this module still imports; CompetitorCheck.validate raises at use time.
try:
    import nltk  # type: ignore
except ImportError:
    nltk = None  # type: ignore

# When nltk is available, make sure the "punkt" sentence-tokenizer data is
# present, downloading it once on first import if it is missing.
if nltk is not None:
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")

# spacy is likewise optional; CompetitorCheck.__init__ checks for None and
# raises an ImportError with install instructions.
try:
    import spacy
except ImportError:
    spacy = None
29+
30+
@register_validator(name="competitor-check", data_type="string")
class CompetitorCheck(Validator):
    """Validates that LLM-generated text is not naming any competitors from a
    given list.

    In order to use this validator you need to provide an extensive list of the
    competitors you want to avoid naming including all common variations.

    Detection is two-staged: a cheap case-insensitive word-boundary match
    first narrows candidate sentences, then spacy NER confirms the competitor
    is actually named as an entity before the sentence is flagged.

    Args:
        competitors (List[str]): List of competitors you want to avoid naming
        on_fail (Optional[Callable]): Callback invoked when validation fails.
    """

    def __init__(
        self,
        competitors: List[str],
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(competitors=competitors, on_fail=on_fail)
        self._competitors = competitors
        model = "en_core_web_trf"
        if spacy is None:
            raise ImportError(
                "You must install spacy in order to use the CompetitorCheck validator."
            )

        # Fetch the NER model on first use; the download can take a while.
        if not spacy.util.is_package(model):
            logger.info(
                f"Spacy model {model} not installed. "
                "Download should start now and take a few minutes."
            )
            spacy.cli.download(model)  # type: ignore

        self.nlp = spacy.load(model)

    def exact_match(self, text: str, competitors: List[str]) -> List[str]:
        """Performs exact match to find competitors from a list in a given
        text.

        Args:
            text (str): The text to search for competitors.
            competitors (list): A list of competitor entities to match.

        Returns:
            list: A list of matched entities.
        """
        found_entities = []
        for entity in competitors:
            # \b anchors avoid partial-word hits (e.g. "Acme" in "Acmeville").
            pattern = rf"\b{re.escape(entity)}\b"
            if re.search(pattern, text, flags=re.IGNORECASE):
                found_entities.append(entity)
        return found_entities

    def perform_ner(self, text: str, nlp) -> List[str]:
        """Performs named entity recognition on text using a provided NLP
        model.

        Args:
            text (str): The text to perform named entity recognition on.
            nlp: The (spacy) NLP model to use for entity recognition.

        Returns:
            entities: A list of entity strings found in the text.
        """
        doc = nlp(text)
        return [ent.text for ent in doc.ents]

    def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List:
        """Checks if any entity from a list is present in a given list of
        competitors.

        Args:
            entities (list): A list of entities to check
            competitors (list): A list of competitor names to match

        Returns:
            List: List of found competitors
        """
        found_competitors = []
        for entity in entities:
            for item in competitors:
                pattern = rf"\b{re.escape(item)}\b"
                if re.search(pattern, entity, flags=re.IGNORECASE):
                    found_competitors.append(item)
        return found_competitors

    def validate(self, value: str, metadata: Optional[Dict] = None) -> ValidationResult:
        """Checks a text to find competitors' names in it.

        While running, store sentences naming competitors and generate a fixed
        output filtering out all flagged sentences.

        Args:
            value (str): The value to be validated.
            metadata (Optional[Dict]): Additional metadata (unused). Defaults
                to None.

        Returns:
            ValidationResult: PassResult when no competitor is named;
                otherwise a FailResult whose fix_value is the input with the
                flagged sentences removed.

        Raises:
            ImportError: If the `nltk` library is not installed.
        """
        if nltk is None:
            raise ImportError(
                "`nltk` library is required for `competitor-check` validator. "
                "Please install it with `poetry add nltk`."
            )
        sentences = nltk.sent_tokenize(value)
        flagged_sentences = []
        filtered_sentences = []
        list_of_competitors_found = []

        for sentence in sentences:
            # Cheap exact match first; only run NER on candidate sentences.
            entities = self.exact_match(sentence, self._competitors)
            if entities:
                ner_entities = self.perform_ner(sentence, self.nlp)
                found_competitors = self.is_entity_in_list(ner_entities, entities)

                if found_competitors:
                    flagged_sentences.append((found_competitors, sentence))
                    list_of_competitors_found.append(found_competitors)
                    logger.debug(f"Found: {found_competitors} named in '{sentence}'")
                else:
                    # Exact match was a false positive (not an NER entity).
                    filtered_sentences.append(sentence)
            else:
                filtered_sentences.append(sentence)

        filtered_output = " ".join(filtered_sentences)

        if flagged_sentences:
            return FailResult(
                error_message=(
                    f"Found the following competitors: {list_of_competitors_found}. "
                    "Please avoid naming those competitors next time"
                ),
                fix_value=filtered_output,
            )
        return PassResult()

guardrails/validators/extractive_summary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def validate(self, value: Any, metadata: Dict) -> ValidationResult:
7373
except ImportError:
7474
raise ImportError(
7575
"`thefuzz` library is required for `extractive-summary` validator. "
76-
"Please install it with `pip install thefuzz`."
76+
"Please install it with `poetry add thefuzz`."
7777
)
7878

7979
# Split the value into sentences.

guardrails/validators/is_high_quality_translation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs):
4545
except ImportError:
4646
raise ImportError(
4747
"`is-high-quality-translation` validator requires the `inspiredco`"
48-
"package. Please install it with `pip install inspiredco`."
48+
"package. Please install it with `poetry add inspiredco`."
4949
)
5050

5151
def validate(self, value: Any, metadata: Dict) -> ValidationResult:

0 commit comments

Comments
 (0)