Skip to content

Commit 33428c6

Browse files
authored
Validator refactor (#478)
* refactor validators to multiple files * complete validator split
1 parent ee18b8c commit 33428c6

32 files changed

+2953
-2540
lines changed

guardrails/validators.py

Lines changed: 0 additions & 2540 deletions
This file was deleted.

guardrails/validators/__init__.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""This module contains the validators for the Guardrails framework.
2+
3+
The name with which a validator is registered is the name that is used
4+
in the `RAIL` spec to specify formatters.
5+
"""
6+
7+
from guardrails.validator_base import (
8+
FailResult,
9+
PassResult,
10+
ValidationResult,
11+
Validator,
12+
register_validator,
13+
)
14+
from guardrails.validators.bug_free_python import BugFreePython
15+
from guardrails.validators.bug_free_sql import BugFreeSQL
16+
from guardrails.validators.detect_secrets import DetectSecrets, detect_secrets
17+
from guardrails.validators.endpoint_is_reachable import EndpointIsReachable
18+
from guardrails.validators.ends_with import EndsWith
19+
from guardrails.validators.exclude_sql_predicates import ExcludeSqlPredicates
20+
from guardrails.validators.extracted_summary_sentences_match import (
21+
ExtractedSummarySentencesMatch,
22+
)
23+
from guardrails.validators.extractive_summary import ExtractiveSummary
24+
from guardrails.validators.is_high_quality_translation import IsHighQualityTranslation
25+
from guardrails.validators.is_profanity_free import IsProfanityFree
26+
from guardrails.validators.lower_case import LowerCase
27+
from guardrails.validators.one_line import OneLine
28+
from guardrails.validators.pii_filter import AnalyzerEngine, AnonymizerEngine, PIIFilter
29+
from guardrails.validators.provenance import ProvenanceV0, ProvenanceV1
30+
from guardrails.validators.pydantic_field_validator import PydanticFieldValidator
31+
from guardrails.validators.qa_relevance_llm_eval import QARelevanceLLMEval
32+
from guardrails.validators.reading_time import ReadingTime
33+
from guardrails.validators.regex_match import RegexMatch
34+
from guardrails.validators.remove_redundant_sentences import RemoveRedundantSentences
35+
from guardrails.validators.saliency_check import SaliencyCheck
36+
from guardrails.validators.similar_to_document import SimilarToDocument
37+
from guardrails.validators.similar_to_list import SimilarToList
38+
from guardrails.validators.sql_column_presence import SqlColumnPresence
39+
from guardrails.validators.two_words import TwoWords
40+
from guardrails.validators.upper_case import UpperCase
41+
from guardrails.validators.valid_choices import ValidChoices
42+
from guardrails.validators.valid_length import ValidLength
43+
from guardrails.validators.valid_range import ValidRange
44+
from guardrails.validators.valid_url import ValidURL
45+
46+
# Public API of the validators package. Every name below is imported at the
# top of this module; the order is kept stable on purpose.
__all__ = [
    # Concrete validators
    "PydanticFieldValidator",
    "ValidRange",
    "ValidChoices",
    "LowerCase",
    "UpperCase",
    "ValidLength",
    "RegexMatch",
    "TwoWords",
    "OneLine",
    "ValidURL",
    "EndpointIsReachable",
    "BugFreePython",
    "BugFreeSQL",
    "SqlColumnPresence",
    "ExcludeSqlPredicates",
    "SimilarToDocument",
    "IsProfanityFree",
    "IsHighQualityTranslation",
    "EndsWith",
    "ExtractedSummarySentencesMatch",
    "ReadingTime",
    "ExtractiveSummary",
    "RemoveRedundantSentences",
    "SaliencyCheck",
    "QARelevanceLLMEval",
    "ProvenanceV0",
    "ProvenanceV1",
    "PIIFilter",
    "SimilarToList",
    "DetectSecrets",
    # Helper objects re-exported from individual validator modules
    "detect_secrets",
    "AnalyzerEngine",
    "AnonymizerEngine",
    # Base classes and result types re-exported from guardrails.validator_base
    "Validator",
    "register_validator",
    "ValidationResult",
    "PassResult",
    "FailResult",
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import ast
2+
import logging
3+
from typing import Any, Dict
4+
5+
from guardrails.validator_base import (
6+
FailResult,
7+
PassResult,
8+
ValidationResult,
9+
Validator,
10+
register_validator,
11+
)
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
@register_validator(name="bug-free-python", data_type="string")
class BugFreePython(Validator):
    """Validates that there are no Python syntactic bugs in the generated code.

    This validator checks for syntax errors by running `ast.parse(code)`.
    The snippet is only parsed, never executed. When parsing fails the
    validator returns a `FailResult` describing the problem instead of
    raising.

    **Key Properties**

    | Property                      | Description                       |
    | ----------------------------- | --------------------------------- |
    | Name for `format` attribute   | `bug-free-python`                 |
    | Supported data types          | `string`                          |
    | Programmatic fix              | None                              |
    """

    def validate(self, value: Any, metadata: Dict) -> ValidationResult:
        """Parse ``value`` as Python source and report whether it is valid.

        Args:
            value: The Python code snippet to check.
            metadata: Validation metadata; not used by this validator.

        Returns:
            ``PassResult`` when ``ast.parse`` accepts the snippet, otherwise a
            ``FailResult`` whose message starts with ``"Syntax error: "``.
        """
        # Lazy %-formatting: the (possibly large) snippet is only rendered
        # when debug logging is actually enabled.
        logger.debug("Validating %s is not a bug...", value)

        # The value is a Python code snippet. We need to check for syntax errors.
        try:
            ast.parse(value)
        except SyntaxError as e:
            return FailResult(
                error_message=f"Syntax error: {e.msg}",
            )
        except ValueError as e:
            # The compiler rejects some inputs with ValueError rather than
            # SyntaxError (e.g. source containing null bytes, or text that
            # cannot be encoded). Treat these as invalid code, not a crash.
            return FailResult(
                error_message=f"Syntax error: {e}",
            )

        return PassResult()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Any, Callable, Dict, Optional
2+
3+
from guardrails.utils.sql_utils import SQLDriver, create_sql_driver
4+
from guardrails.validator_base import (
5+
FailResult,
6+
PassResult,
7+
ValidationResult,
8+
Validator,
9+
register_validator,
10+
)
11+
12+
13+
@register_validator(name="bug-free-sql", data_type=["string"])
class BugFreeSQL(Validator):
    """Validates that there are no SQL syntactic bugs in the generated code.

    The check itself is delegated to the SQL driver returned by
    `create_sql_driver`. This is a very minimal implementation that uses the
    Pypi `sqlvalidator` package to check if the SQL query is valid. You can
    implement a custom SQL validator that uses a database connection to check
    if the query is valid.

    **Key Properties**

    | Property                      | Description                       |
    | ----------------------------- | --------------------------------- |
    | Name for `format` attribute   | `bug-free-sql`                    |
    | Supported data types          | `string`                          |
    | Programmatic fix              | None                              |
    """

    def __init__(
        self,
        conn: Optional[str] = None,
        schema_file: Optional[str] = None,
        on_fail: Optional[Callable] = None,
    ):
        """Build the SQL driver used for validation.

        Args:
            conn: Forwarded to `create_sql_driver` as `conn`
                (presumably a database connection string; confirm there).
            schema_file: Forwarded to `create_sql_driver` as `schema_file`.
            on_fail: Failure handler passed through to the base `Validator`.
        """
        super().__init__(on_fail=on_fail)
        sql_driver: SQLDriver = create_sql_driver(schema_file=schema_file, conn=conn)
        self._driver = sql_driver

    def validate(self, value: Any, metadata: Dict) -> ValidationResult:
        """Run the driver's SQL validation on ``value``.

        Returns a ``PassResult`` when the driver reports no errors; otherwise a
        ``FailResult`` whose message is the reported errors joined by ``". "``.
        """
        problems = self._driver.validate_sql(value)

        # Nothing reported: the query is considered valid.
        if len(problems) == 0:
            return PassResult()

        combined_message = ". ".join(problems)
        return FailResult(error_message=combined_message)
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import os
2+
import warnings
3+
from typing import Any, Callable, Dict, List, Tuple, Union
4+
5+
from guardrails.validator_base import (
6+
FailResult,
7+
PassResult,
8+
ValidationResult,
9+
Validator,
10+
register_validator,
11+
)
12+
13+
try:
14+
import detect_secrets # type: ignore
15+
except ImportError:
16+
detect_secrets = None
17+
18+
19+
@register_validator(name="detect-secrets", data_type="string")
class DetectSecrets(Validator):
    """Validates whether the generated code snippet contains any secrets.

    **Key Properties**

    | Property                      | Description                             |
    | ----------------------------- | --------------------------------------- |
    | Name for `format` attribute   | `detect-secrets`                        |
    | Supported data types          | `string`                                |
    | Programmatic fix              | Detected secrets replaced by `********` |

    Parameters: Arguments
        None

    This validator uses the detect-secrets library to check whether the generated code
    snippet contains any secrets. If any secrets are detected, the validator fails and
    returns the generated code snippet with the secrets replaced with asterisks.
    Else the validator returns the generated code snippet.

    Following are some caveats:
        - Multiple secrets on the same line may not be caught. e.g.
            - Minified code
            - One-line lists/dictionaries
            - Multi-variable assignments
        - Multi-line secrets may not be caught. e.g.
            - RSA/SSH keys

    Example:
        ```py

        guard = Guard.from_string(validators=[
            DetectSecrets(on_fail="fix")
        ])
        guard.parse(
            llm_output=code_snippet,
        )
        ```
    """

    def __init__(self, on_fail: Union[Callable[..., Any], None] = None, **kwargs):
        """Initialise the validator and verify the optional dependency.

        Args:
            on_fail: Failure handler passed through to the base `Validator`.
            **kwargs: Extra keyword arguments forwarded to the base `Validator`.

        Raises:
            ValueError: If the `detect-secrets` package is not installed.
        """
        super().__init__(on_fail, **kwargs)

        # Check if detect-secrets is installed
        if detect_secrets is None:
            raise ValueError(
                "You must install detect-secrets in order to "
                "use the DetectSecrets validator."
            )
        # Base name of the scratch file handed to detect-secrets. The file is
        # created inside a private temporary directory, never in the cwd.
        self.temp_file_name = "temp.txt"
        # Replacement text written over every detected secret.
        self.mask = "********"

    def get_unique_secrets(self, value: str) -> Tuple[Dict[str, Any], List[str]]:
        """Get unique secrets from the value.

        The snippet is written to a scratch file inside a private temporary
        directory (detect-secrets scans files), scanned, read back line by
        line, and the directory is removed again even when a step fails.

        Args:
            value (str): The generated code snippet.

        Returns:
            unique_secrets (Dict[str, Any]): A dictionary of unique secrets and their
                line numbers.
            lines (List[str]): The lines of the generated code snippet.

        Raises:
            OSError: If the scratch file cannot be written or read.
            ValueError: If detect-secrets cannot be imported.
            RuntimeError: If creating the collection or scanning fails.
        """
        import tempfile

        # Only the base name is used so the scratch file always lives inside
        # the private directory, whatever `temp_file_name` was set to.
        scratch_name = os.path.basename(self.temp_file_name) or "temp.txt"

        # The context manager guarantees cleanup on every exit path, so a
        # failed scan no longer leaves the snippet (and its secrets) on disk.
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch_path = os.path.join(scratch_dir, scratch_name)

            try:
                # Write the snippet to the scratch file
                with open(scratch_path, "w") as f:
                    f.write(value)
            except Exception as e:
                raise OSError(
                    "Problems creating or deleting the temporary file. "
                    "Please check the permissions of the current directory."
                ) from e

            try:
                # Create a new secrets collection
                from detect_secrets import settings
                from detect_secrets.core.secrets_collection import SecretsCollection

                secrets = SecretsCollection()

                # Scan the file for secrets
                with settings.default_settings():
                    secrets.scan_file(scratch_path)
            except ImportError:
                raise ValueError(
                    "You must install detect-secrets in order to "
                    "use the DetectSecrets validator."
                )
            except Exception as e:
                raise RuntimeError(
                    "Problems with creating a SecretsCollection or "
                    "scanning the file for secrets."
                ) from e

            try:
                # Read the lines back from the file that was scanned, so the
                # line numbers reported by the scan index into this list.
                with open(scratch_path, "r") as f:
                    lines = f.readlines()
            except Exception as e:
                raise OSError(
                    "Problems reading the temporary file. "
                    "Please check the permissions of the current directory."
                ) from e

        # Group the findings by secret value, keeping each line number once.
        unique_secrets: Dict[str, Any] = {}
        for _, potential_secret in secrets:
            line_numbers = unique_secrets.setdefault(
                potential_secret.secret_value, []
            )
            if potential_secret.line_number not in line_numbers:
                line_numbers.append(potential_secret.line_number)

        return unique_secrets, lines

    def get_modified_value(
        self, unique_secrets: Dict[str, Any], lines: List[str]
    ) -> str:
        """Replace the secrets on the lines with asterisks.

        Args:
            unique_secrets (Dict[str, Any]): A dictionary of unique secrets and their
                line numbers.
            lines (List[str]): The lines of the generated code snippet.

        Returns:
            modified_value (str): The generated code snippet with secrets replaced with
                asterisks.
        """
        # Work on a copy so the caller's list is left untouched.
        masked_lines = list(lines)

        # Line numbers are used 1-based here, hence the `- 1` when indexing.
        for secret, line_numbers in unique_secrets.items():
            for line_number in line_numbers:
                masked_lines[line_number - 1] = masked_lines[line_number - 1].replace(
                    secret, self.mask
                )

        # Convert lines to a multiline string
        return "".join(masked_lines)

    def validate(self, value: str, metadata: Dict[str, Any]) -> ValidationResult:
        """Scan ``value`` for secrets and fail with a masked fix if any are found.

        Args:
            value: The generated code snippet.
            metadata: Validation metadata; not used by this validator.

        Returns:
            ``PassResult`` when no secrets are detected; otherwise a
            ``FailResult`` listing the secrets, with ``fix_value`` set to the
            snippet with each secret replaced by ``self.mask``.
        """
        # Check if value is a multiline string
        if "\n" not in value:
            # Raise warning if value is not a multiline string
            warnings.warn(
                "The DetectSecrets validator works best with "
                "multiline code snippets. "
                "Refer validator docs for more details."
            )

            # Add a newline to value
            value += "\n"

        # Get unique secrets from the value
        unique_secrets, lines = self.get_unique_secrets(value)

        if unique_secrets:
            # Replace the secrets on the lines with asterisks
            modified_value = self.get_modified_value(unique_secrets, lines)

            return FailResult(
                error_message=(
                    "The following secrets were detected in your response:\n"
                    + "\n".join(unique_secrets.keys())
                ),
                fix_value=modified_value,
            )
        return PassResult()

0 commit comments

Comments
 (0)