Skip to content

Commit 096a2d1

Browse files
Fix merge conflicts
2 parents fc0d718 + bcf8066 commit 096a2d1

File tree

12 files changed

+1028
-84
lines changed

12 files changed

+1028
-84
lines changed

.github/workflows/scripts/run_notebooks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cd docs/examples
99
# Function to process a notebook
1010
process_notebook() {
1111
notebook="$1"
12-
invalid_notebooks=("valid_chess_moves.ipynb" "translation_with_quality_check.ipynb" "llamaindex-output-parsing.ipynb")
12+
invalid_notebooks=("valid_chess_moves.ipynb" "translation_with_quality_check.ipynb" "llamaindex-output-parsing.ipynb" "competitors_check.ipynb")
1313
if [[ ! " ${invalid_notebooks[@]} " =~ " ${notebook} " ]]; then
1414
echo "Processing $notebook..."
1515
poetry run jupyter nbconvert --to notebook --execute "$notebook"
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Input Validation\n",
8+
"\n",
9+
"Guardrails supports validating inputs (prompts, instructions, msg_history) with string validators."
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"In XML, specify the validators on the `prompt` or `instructions` tag, as such:"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {
23+
"is_executing": true
24+
},
25+
"outputs": [],
26+
"source": [
27+
"rail_spec = \"\"\"\n",
28+
"<rail version=\"0.1\">\n",
29+
"<prompt\n",
30+
" validators=\"two-words\"\n",
31+
" on-fail-two-words=\"exception\"\n",
32+
">\n",
33+
"This is not two words\n",
34+
"</prompt>\n",
35+
"<output type=\"string\">\n",
36+
"</output>\n",
37+
"</rail>\n",
38+
"\"\"\"\n",
39+
"\n",
40+
"from guardrails import Guard\n",
41+
"guard = Guard.from_rail_string(rail_spec)"
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"When `fix` is specified as the on-fail handler, the prompt will automatically be amended before calling the LLM.\n",
49+
"\n",
50+
"In any other case (for example, `exception`), a `ValidationException` will be returned in the outcome."
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {
57+
"is_executing": true
58+
},
59+
"outputs": [],
60+
"source": [
61+
"import openai\n",
62+
"\n",
63+
"outcome = guard(\n",
64+
" openai.ChatCompletion.create,\n",
65+
")\n",
66+
"outcome.error"
67+
]
68+
},
69+
{
70+
"cell_type": "markdown",
71+
"metadata": {},
72+
"source": [
73+
"When using pydantic to initialize a `Guard`, input validators can be specified by composition, as such:"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"metadata": {},
80+
"outputs": [],
81+
"source": [
82+
"from guardrails.validators import TwoWords\n",
83+
"from pydantic import BaseModel\n",
84+
"\n",
85+
"\n",
86+
"class Pet(BaseModel):\n",
87+
" name: str\n",
88+
" age: int\n",
89+
"\n",
90+
"\n",
91+
"guard = Guard.from_pydantic(Pet)\n",
92+
"guard.with_prompt_validation([TwoWords(on_fail=\"exception\")])\n",
93+
"\n",
94+
"outcome = guard(\n",
95+
" openai.ChatCompletion.create,\n",
96+
" prompt=\"This is not two words\",\n",
97+
")\n",
98+
"outcome.error"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3 (ipykernel)",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"codemirror_mode": {
110+
"name": "ipython",
111+
"version": 3
112+
},
113+
"file_extension": ".py",
114+
"mimetype": "text/x-python",
115+
"name": "python",
116+
"nbconvert_exporter": "python",
117+
"pygments_lexer": "ipython3",
118+
"version": "3.11.0"
119+
}
120+
},
121+
"nbformat": 4,
122+
"nbformat_minor": 1
123+
}

guardrails/classes/history/call.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ def error(self) -> Optional[str]:
281281
return None
282282
return self.iterations.last.error # type: ignore
283283

284+
@property
285+
def exception(self) -> Optional[Exception]:
286+
"""The exception that interrupted the run."""
287+
if self.iterations.empty():
288+
return None
289+
return self.iterations.last.exception # type: ignore
290+
284291
@property
285292
def failed_validations(self) -> Stack[ValidatorLogs]:
286293
"""The validator logs for any validations that failed during the

guardrails/classes/history/iteration.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def error(self) -> Optional[str]:
113113
this iteration."""
114114
return self.outputs.error
115115

116+
@property
117+
def exception(self) -> Optional[Exception]:
118+
"""The exception that interrupted this iteration."""
119+
return self.outputs.exception
120+
116121
@property
117122
def failed_validations(self) -> List[ValidatorLogs]:
118123
"""The validator logs for any validations that failed during this

guardrails/classes/history/outputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class Outputs(ArbitraryModel):
4646
"that raised and interrupted the process.",
4747
default=None,
4848
)
49+
exception: Optional[Exception] = Field(
50+
description="The exception that interrupted the process.", default=None
51+
)
4952

5053
def _all_empty(self) -> bool:
5154
return (

guardrails/datatypes.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,11 @@ def collect_validation(
395395
schema: Dict,
396396
) -> FieldValidation:
397397
# Validators in the main list data type are applied to the list overall.
398-
399398
validation = self._constructor_validation(key, value)
400399

400+
if value is None and self.optional:
401+
return validation
402+
401403
if len(self._children) == 0:
402404
return validation
403405

@@ -435,9 +437,11 @@ def collect_validation(
435437
schema: Dict,
436438
) -> FieldValidation:
437439
# Validators in the main object data type are applied to the object overall.
438-
439440
validation = self._constructor_validation(key, value)
440441

442+
if value is None and self.optional:
443+
return validation
444+
441445
if len(self._children) == 0:
442446
return validation
443447

guardrails/guard.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import contextvars
3+
import warnings
34
from typing import (
45
Any,
56
Awaitable,
@@ -28,7 +29,7 @@
2829
from guardrails.prompt import Instructions, Prompt
2930
from guardrails.rail import Rail
3031
from guardrails.run import AsyncRunner, Runner, StreamRunner
31-
from guardrails.schema import Schema
32+
from guardrails.schema import Schema, StringSchema
3233
from guardrails.validators import Validator
3334

3435
add_destinations(logger.debug)
@@ -65,9 +66,19 @@ def __init__(
6566
self.base_model = base_model
6667

6768
@property
68-
def input_schema(self) -> Optional[Schema]:
69+
def prompt_schema(self) -> Optional[StringSchema]:
6970
"""Return the input schema."""
70-
return self.rail.input_schema
71+
return self.rail.prompt_schema
72+
73+
@property
74+
def instructions_schema(self) -> Optional[StringSchema]:
75+
"""Return the input schema."""
76+
return self.rail.instructions_schema
77+
78+
@property
79+
def msg_history_schema(self) -> Optional[StringSchema]:
80+
"""Return the input schema."""
81+
return self.rail.msg_history_schema
7182

7283
@property
7384
def output_schema(self) -> Schema:
@@ -458,7 +469,9 @@ async def _call_async(
458469
prompt=prompt_obj,
459470
msg_history=msg_history_obj,
460471
api=get_async_llm_ask(llm_api, *args, **kwargs),
461-
input_schema=self.input_schema,
472+
prompt_schema=self.prompt_schema,
473+
instructions_schema=self.instructions_schema,
474+
msg_history_schema=self.msg_history_schema,
462475
output_schema=self.output_schema,
463476
num_reasks=num_reasks,
464477
metadata=metadata,
@@ -634,7 +647,9 @@ def _sync_parse(
634647
prompt=kwargs.pop("prompt", None),
635648
msg_history=kwargs.pop("msg_history", None),
636649
api=get_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
637-
input_schema=None,
650+
prompt_schema=self.prompt_schema,
651+
instructions_schema=self.instructions_schema,
652+
msg_history_schema=self.msg_history_schema,
638653
output_schema=self.output_schema,
639654
num_reasks=num_reasks,
640655
metadata=metadata,
@@ -674,7 +689,9 @@ async def _async_parse(
674689
prompt=kwargs.pop("prompt", None),
675690
msg_history=kwargs.pop("msg_history", None),
676691
api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
677-
input_schema=None,
692+
prompt_schema=self.prompt_schema,
693+
instructions_schema=self.instructions_schema,
694+
msg_history_schema=self.msg_history_schema,
678695
output_schema=self.output_schema,
679696
num_reasks=num_reasks,
680697
metadata=metadata,
@@ -687,3 +704,54 @@ async def _async_parse(
687704
)
688705

689706
return ValidationOutcome[OT].from_guard_history(call, error_message)
707+
708+
def with_prompt_validation(
709+
self,
710+
validators: Sequence[Validator],
711+
):
712+
"""Add prompt validation to the Guard.
713+
714+
Args:
715+
validators: The validators to add to the prompt.
716+
"""
717+
if self.rail.prompt_schema:
718+
warnings.warn("Overriding existing prompt validators.")
719+
schema = StringSchema.from_string(
720+
validators=validators,
721+
)
722+
self.rail.prompt_schema = schema
723+
return self
724+
725+
def with_instructions_validation(
726+
self,
727+
validators: Sequence[Validator],
728+
):
729+
"""Add instructions validation to the Guard.
730+
731+
Args:
732+
validators: The validators to add to the instructions.
733+
"""
734+
if self.rail.instructions_schema:
735+
warnings.warn("Overriding existing instructions validators.")
736+
schema = StringSchema.from_string(
737+
validators=validators,
738+
)
739+
self.rail.instructions_schema = schema
740+
return self
741+
742+
def with_msg_history_validation(
743+
self,
744+
validators: Sequence[Validator],
745+
):
746+
"""Add msg_history validation to the Guard.
747+
748+
Args:
749+
validators: The validators to add to the msg_history.
750+
"""
751+
if self.rail.msg_history_schema:
752+
warnings.warn("Overriding existing msg_history validators.")
753+
schema = StringSchema.from_string(
754+
validators=validators,
755+
)
756+
self.rail.msg_history_schema = schema
757+
return self

0 commit comments

Comments
 (0)