Skip to content

Commit dcd9928

Browse files
wip: validation summary integration
1 parent 7e157bd commit dcd9928

File tree

5 files changed

+83
-2
lines changed

5 files changed

+83
-2
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# TODO: Temporary implementation; update once the generated class is in
2+
from typing import List, Optional
3+
4+
from guardrails.classes.generic.arbitrary_model import ArbitraryModel
5+
from guardrails.classes.validation.validation_result import ErrorSpan, FailResult
6+
from guardrails.classes.validation.validator_logs import ValidatorLogs
7+
8+
9+
class ValidationSummary(ArbitraryModel):
    """Condensed, per-validator view of a validation run.

    One instance is built from one ``ValidatorLogs`` entry; see
    ``from_validator_logs``.
    """

    # Copied from ``ValidatorLogs.validator_name``.
    validator_name: str
    # The ``outcome`` attribute of the log's validation result.
    validator_status: str
    # ``error_message`` of a ``FailResult``; None for any other result.
    failure_reason: Optional[str] = None
    # ``error_spans`` of a ``FailResult``; empty for any other result.
    error_spans: Optional[List["ErrorSpan"]] = []
    # Copied from ``ValidatorLogs.property_path``.
    property_path: Optional[str] = None

    @staticmethod
    def from_validator_logs(
        validator_logs: Optional[List[ValidatorLogs]],
    ) -> List["ValidationSummary"]:
        """Build one summary per validator log entry.

        Args:
            validator_logs: The logs to summarize. ``None`` or an empty
                list yields an empty result.

        Returns:
            A list of ``ValidationSummary`` objects, in the same order as
            the input logs.
        """
        summaries: List["ValidationSummary"] = []
        # Tolerate a missing list so callers do not each need an
        # ``or []`` guard before calling this helper.
        for log in validator_logs or []:
            validation_result = log.validation_result
            # Only failed results carry an error message and error spans.
            if isinstance(validation_result, FailResult):
                failure_reason = validation_result.error_message
                error_spans = validation_result.error_spans
            else:
                failure_reason = None
                error_spans = []
            summaries.append(
                ValidationSummary(
                    validator_name=log.validator_name,
                    # NOTE(review): this raises AttributeError if a log has
                    # no validation result yet -- confirm whether that can
                    # happen for in-flight (streaming) validators.
                    validator_status=validation_result.outcome,
                    property_path=log.property_path,
                    failure_reason=failure_reason,
                    error_spans=error_spans,
                )
            )
        return summaries

guardrails/classes/validation_outcome.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, Iterator, Optional, Tuple, Union, cast
1+
from typing import Generic, Iterator, List, Optional, Tuple, Union, cast
22

33
from pydantic import Field
44
from rich.pretty import pretty_repr
@@ -11,6 +11,7 @@
1111
from guardrails.classes.history import Call, Iteration
1212
from guardrails.classes.output_type import OT
1313
from guardrails.classes.generic.arbitrary_model import ArbitraryModel
14+
from guardrails.classes.validation.validation_summary import ValidationSummary
1415
from guardrails.constants import pass_status
1516
from guardrails.utils.safe_get import safe_get
1617

@@ -31,6 +32,11 @@ class ValidationOutcome(IValidationOutcome, ArbitraryModel, Generic[OT]):
3132
error: If the validation failed, this field will contain the error message
3233
"""
3334

35+
validation_summaries: Optional[List["ValidationSummary"]] = Field(
36+
description="The summaries of the validation results.", default=[]
37+
)
38+
"""The summaries of the validation results."""
39+
3440
raw_llm_output: Optional[str] = Field(
3541
description="The raw, unchanged output from the LLM call.", default=None
3642
)
@@ -75,6 +81,8 @@ def from_guard_history(cls, call: Call):
7581
list(last_iteration.reasks), 0
7682
)
7783
validation_passed = call.status == pass_status
84+
validator_logs = last_iteration.validator_logs or []
85+
validation_summaries = ValidationSummary.from_validator_logs(validator_logs)
7886
reask = last_output if isinstance(last_output, ReAsk) else None
7987
error = call.error
8088
output = cast(OT, call.guarded_output)
@@ -84,6 +92,7 @@ def from_guard_history(cls, call: Call):
8492
validated_output=output,
8593
reask=reask,
8694
validation_passed=validation_passed,
95+
validation_summaries=validation_summaries,
8796
error=error,
8897
)
8998

guardrails/guard.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,11 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
11931193
)
11941194
self.history.extend([Call.from_interface(call) for call in guard_history])
11951195

1196+
# TODO Validation Summary
1197+
# validator_logs = self.history.last.iterations.last.validator_logs
1198+
# validation_summaries = ValidationSummary.
1199+
# from_validator_logs(validator_logs)
1200+
11961201
# TODO: See if the below statement is still true
11971202
# Our interfaces are too different for this to work right now.
11981203
# Once we move towards shared interfaces for both the open source
@@ -1203,6 +1208,7 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
12031208
if validation_output.validated_output
12041209
else None
12051210
)
1211+
# TODO: Validation Summary
12061212
return ValidationOutcome[OT](
12071213
call_id=validation_output.call_id, # type: ignore
12081214
raw_llm_output=validation_output.raw_llm_output,
@@ -1224,9 +1230,11 @@ def _stream_server_call(
12241230
payload=ValidatePayload.from_dict(payload), # type: ignore
12251231
openai_api_key=get_call_kwarg("api_key"),
12261232
)
1233+
print("Server response:", response)
12271234
for fragment in response:
12281235
validation_output = fragment
12291236
if validation_output is None:
1237+
# TODO Validation Summary
12301238
yield ValidationOutcome[OT](
12311239
call_id="0", # type: ignore
12321240
raw_llm_output=None,
@@ -1240,6 +1248,7 @@ def _stream_server_call(
12401248
if validation_output.validated_output
12411249
else None
12421250
)
1251+
# TODO Validation Summary
12431252
yield ValidationOutcome[OT](
12441253
call_id=validation_output.call_id, # type: ignore
12451254
raw_llm_output=validation_output.raw_llm_output,

guardrails/run/async_stream_runner.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from guardrails.classes import ValidationOutcome
1414
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
1515
from guardrails.classes.output_type import OutputTypes
16+
from guardrails.classes.validation.validation_summary import ValidationSummary
1617
from guardrails.constants import pass_status
1718
from guardrails.llm_providers import (
1819
AsyncLiteLLMCallable,
@@ -164,11 +165,16 @@ async def async_step(
164165
)
165166
validation_response += cast(str, validated_fragment)
166167
passed = call_log.status == pass_status
168+
validator_logs = iteration.validator_logs
169+
validation_summaries = ValidationSummary.from_validator_logs(
170+
validator_logs
171+
)
167172
yield ValidationOutcome(
168173
call_id=call_log.id, # type: ignore
169174
raw_llm_output=chunk_text,
170175
validated_output=validated_fragment,
171176
validation_passed=passed,
177+
validation_summaries=validation_summaries,
172178
)
173179
else:
174180
async for chunk in stream_output:
@@ -204,11 +210,17 @@ async def async_step(
204210
validation_response = cast(list, validated_fragment)
205211
else:
206212
validation_response = cast(dict, validated_fragment)
213+
214+
validator_logs = iteration.validator_logs
215+
validation_summaries = ValidationSummary.from_validator_logs(
216+
validator_logs
217+
)
207218
yield ValidationOutcome(
208219
call_id=call_log.id, # type: ignore
209220
raw_llm_output=fragment,
210221
validated_output=chunk_text,
211222
validation_passed=validated_fragment is not None,
223+
validation_summaries=validation_summaries,
212224
)
213225

214226
iteration.outputs.raw_output = fragment

guardrails/run/stream_runner.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union, cast
22

3+
34
from guardrails import validator_service
45
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
56
from guardrails.classes.output_type import OT, OutputTypes
7+
from guardrails.classes.validation.validation_summary import ValidationSummary
68
from guardrails.classes.validation_outcome import ValidationOutcome
79
from guardrails.llm_providers import (
810
LiteLLMCallable,
@@ -176,7 +178,9 @@ def prepare_chunk_generator(stream) -> Iterable[Tuple[Any, bool]]:
176178
"$",
177179
validate_subschema=True,
178180
)
179-
181+
# Not sure I like adding all this info to every chunk
182+
# maybe move it to the last chunk only?
183+
validator_logs = iteration.validator_logs
180184
for res in gen:
181185
chunk = res.chunk
182186
original_text = res.original_text
@@ -195,13 +199,19 @@ def prepare_chunk_generator(stream) -> Iterable[Tuple[Any, bool]]:
195199
)
196200
# 5. Convert validated fragment to a pretty JSON string
197201
validation_response += cast(str, chunk)
202+
validator_logs = call_log.iterations.last.validator_logs
203+
204+
validation_summaries = ValidationSummary.from_validator_logs(
205+
validator_logs
206+
)
198207
passed = call_log.status == pass_status
199208
yield ValidationOutcome(
200209
call_id=call_log.id, # type: ignore
201210
# The chunk or the whole output?
202211
raw_llm_output=original_text,
203212
validated_output=chunk,
204213
validation_passed=passed,
214+
validation_summaries=validation_summaries,
205215
)
206216

207217
# handle non string schema
@@ -246,11 +256,17 @@ def prepare_chunk_generator(stream) -> Iterable[Tuple[Any, bool]]:
246256
else:
247257
validation_response = cast(dict, validated_fragment)
248258
# 5. Convert validated fragment to a pretty JSON string
259+
260+
validator_logs = iteration.validator_logs
261+
validation_summaries = ValidationSummary.from_validator_logs(
262+
validator_logs
263+
)
249264
yield ValidationOutcome(
250265
call_id=call_log.id, # type: ignore
251266
raw_llm_output=fragment,
252267
validated_output=validated_fragment,
253268
validation_passed=validated_fragment is not None,
269+
validation_summaries=validation_summaries,
254270
)
255271

256272
# # Finally, add to logs

0 commit comments

Comments
 (0)