Skip to content

Commit 98b4652

Browse files
committed
fix conflicts
2 parents 6dc22ef + d3b8300 commit 98b4652

40 files changed

+2200
-2128
lines changed

README.md

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,23 @@ pip install guardrails-ai
 3. Create a Guard from the installed guardrail.
 
 ```python
-# Import Guard and Validator
-from guardrails.hub import RegexMatch
 from guardrails import Guard
+from guardrails.hub import RegexMatch
 
-# Initialize the Guard with
-val = Guard().use(
-    RegexMatch(regex="^[A-Z][a-z]*$")
+guard = Guard().use(
+    RegexMatch, regex="\(?\d{3}\)?-? *\d{3}-? *-?\d{4}", on_fail="exception"
 )
 
-guard.parse("Caesar")  # Guardrail Passes
-guard.parse("Caesar is a great leader")  # Guardrail Fails
+guard.validate("123-456-7890")  # Guardrail passes
+
+try:
+    guard.validate("1234-789-0000")  # Guardrail fails
+except Exception as e:
+    print(e)
+```
+Output:
+```console
+Validation failed for field with errors: Result must match \(?\d{3}\)?-? *\d{3}-? *-?\d{4}
 ```
 4. Run multiple guardrails within a Guard.
 First, install the necessary guardrails from Guardrails Hub.
@@ -87,18 +93,32 @@ pip install guardrails-ai
 Then, create a Guard from the installed guardrails.
 
 ```python
-from guardrails.hub import RegexMatch, ValidLength
 from guardrails import Guard
+from guardrails.hub import CompetitorCheck, ToxicLanguage
 
-guard = Guard().use(
-    RegexMatch(regex="^[A-Z][a-z]*$"),
-    ValidLength(min=1, max=32)
+guard = Guard().use_many(
+    CompetitorCheck(["Apple", "Microsoft", "Google"], on_fail="exception"),
+    ToxicLanguage(threshold=0.5, validation_method="sentence", on_fail="exception"),
 )
 
-guard.parse("Caesar")  # Guardrail Passes
-guard.parse("Caesar is a great leader")  # Guardrail Fails
+guard.validate(
+    """An apple a day keeps a doctor away.
+    This is good advice for keeping your health."""
+)  # Both the guardrails pass
+
+try:
+    guard.validate(
+        """Shut the hell up! Apple just released a new iPhone."""
+    )  # Both the guardrails fail
+except Exception as e:
+    print(e)
 ```
+Output:
+```console
+Validation failed for field with errors: Found the following competitors: [['Apple']]. Please avoid naming those competitors next time, The following sentences in your response were found to be toxic:
 
+- Shut the hell up!
+```
 
 ### Use Guardrails to generate structured data from LLMs
 
@@ -133,7 +153,7 @@ validated_output, *rest = guard(
     engine="gpt-3.5-turbo-instruct"
 )
 
-print(f"{validated_output}")
+print(validated_output)
 ```
 
 This prints:

docs/examples/generate_structured_data_cohere.ipynb

Lines changed: 89 additions & 105 deletions
Large diffs are not rendered by default.

docs/how_to_guides/streaming.ipynb

Lines changed: 347 additions & 465 deletions
Large diffs are not rendered by default.

docs/hub/api_reference_markdown/validators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ Validates whether the generated code snippet contains any secrets.
 ```py
 
 guard = Guard.from_string(validators=[
-    DetectSecrets(on_fail="fix")
+    DetectSecrets(on_fail=OnFailAction.FIX)
 ])
 guard.parse(
     llm_output=code_snippet,

docs/llm_api_wrappers.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ guard = Guard.from_string(
     validators=[
         ValidLength(
             min=48,
-            on_fail="fix"
+            on_fail=OnFailAction.FIX
         ),
         ToxicLanguage(
-            on_fail="fix"
+            on_fail=OnFailAction.FIX
         )
     ],
     prompt=prompt
@@ -179,10 +179,10 @@ guard = Guard.from_string(
     validators=[
         ValidLength(
             min=48,
-            on_fail="fix"
+            on_fail=OnFailAction.FIX
         ),
         ToxicLanguage(
-            on_fail="fix"
+            on_fail=OnFailAction.FIX
         )
     ],
     prompt=prompt

guardrails/cli/hub/install.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
-import json
 import os
 import subprocess
 import sys
@@ -34,7 +33,7 @@ def pip_process(
     package: str = "",
     flags: List[str] = [],
     format: Union[Literal["string"], Literal["json"]] = string_format,
-):
+) -> Union[str, dict]:
     try:
         logger.debug(f"running pip {action} {' '.join(flags)} {package}")
         command = [sys.executable, "-m", "pip", action]
@@ -44,7 +43,11 @@ def pip_process(
         output = subprocess.check_output(command)
         logger.debug(f"decoding output from pip {action} {package}")
         if format == json_format:
-            return BytesHeaderParser().parsebytes(output)
+            parsed = BytesHeaderParser().parsebytes(output)
+            accumulator = {}
+            for key, value in parsed.items():
+                accumulator[key] = value
+            return accumulator
         return str(output.decode())
     except subprocess.CalledProcessError as exc:
         logger.error(
@@ -197,9 +200,14 @@ def install_hub_module(module_manifest: ModuleManifest, site_packages: str):
     inspect_output = pip_process(
         "inspect", flags=[f"--path={install_directory}"], format=json_format
     )
-    inspection: dict = json.loads(str(inspect_output))
+
+    # throw if inspect_output is a string. Mostly for pyright
+    if isinstance(inspect_output, str):
+        logger.error("Failed to inspect the installed package!")
+        sys.exit(1)
+
     dependencies = (
-        Stack(*inspection.get("installed", []))
+        Stack(*inspect_output.get("installed", []))
         .at(0, {})
         .get("metadata", {})  # type: ignore
        .get("requires_dist", [])  # type: ignore

guardrails/guard.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __call__(
         stream: Optional[bool] = False,
         *args,
         **kwargs,
-    ) -> Union[ValidationOutcome[OT], Iterable[str]]:
+    ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
         ...
 
     @overload
@@ -499,7 +499,8 @@ def __call__(
         *args,
         **kwargs,
     ) -> Union[
-        Union[ValidationOutcome[OT], Iterable[str]], Awaitable[ValidationOutcome[OT]]
+        Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]],
+        Awaitable[ValidationOutcome[OT]],
     ]:
         """Call the LLM and validate the output. Pass an async LLM API to
         return a coroutine.
@@ -663,7 +664,7 @@ def _call_sync(
         call_log: Call,
         *args,
         **kwargs,
-    ) -> Union[ValidationOutcome[OT], Iterable[str]]:
+    ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]:
         instructions_obj = instructions or self.instructions
         prompt_obj = prompt or self.prompt
         msg_history_obj = msg_history or []

guardrails/llm_providers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,25 @@ def _invoke_llm(
         if "instructions" in kwargs:
             prompt = kwargs.pop("instructions") + "\n\n" + prompt
 
+        def is_base_cohere_chat(func):
+            try:
+                return (
+                    func.__closure__[1].cell_contents.__func__.__qualname__
+                    == "BaseCohere.chat"
+                )
+            except (AttributeError, IndexError):
+                return False
+
+        # TODO: When cohere totally gets rid of `generate`,
+        # remove this cond and the final return
+        if is_base_cohere_chat(client_callable):
+            cohere_response = client_callable(
+                message=prompt, model=model, *args, **kwargs
+            )
+            return LLMResponse(
+                output=cohere_response.text,
+            )
+
         cohere_response = client_callable(prompt=prompt, model=model, *args, **kwargs)
         return LLMResponse(
             output=cohere_response[0].text,
@@ -562,7 +581,7 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
         if (
             isinstance(getattr(llm_api, "__self__", None), cohere.Client)
             and getattr(llm_api, "__name__", None) == "generate"
-        ):
+        ) or getattr(llm_api, "__module__", None) == "cohere.client":
             return CohereCallable(*args, client_callable=llm_api, **kwargs)
     except ImportError:
         pass

0 commit comments

Comments
 (0)