
Commit beaed76

Merge pull request #1113 from guardrails-ai/async_streaming_fixes
Async streaming fixes to be consistent with sync streaming
2 parents ec75464 + a2e53ae commit beaed76
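
For context, this commit changes what an AsyncGuard yields while streaming: each emitted outcome now pairs the raw fragment with output merged across all validators, matching sync streaming. A minimal consumption sketch follows; the hub LowerCase import, model name, and prompt are illustrative assumptions, and only the AsyncGuard streaming call shape comes from this PR's tests.

# Minimal sketch (not part of the commit): consuming async streaming validation.
# Assumptions: guardrails-ai with the hub LowerCase validator installed and an
# OpenAI key available for gpt-3.5-turbo.
import asyncio

import guardrails as gd
from guardrails.hub import LowerCase  # assumed hub validator
from guardrails.types import OnFailAction


async def main():
    guard = gd.AsyncGuard().use(LowerCase(on_fail=OnFailAction.FIX))
    fragments = await guard(
        model="gpt-3.5-turbo",
        prompt="Write one sentence about San Francisco.",
        stream=True,
    )
    async for outcome in fragments:
        # After this change, each outcome carries the raw fragment plus the
        # output merged across every validator for that fragment.
        print(outcome.raw_llm_output, "->", outcome.validated_output)


asyncio.run(main())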

File tree (6 files changed, +215 -68 lines):

  guardrails/classes/validation/validation_summary.py
  guardrails/run/async_stream_runner.py
  guardrails/validator_service/async_validator_service.py
  guardrails/validator_service/sequential_validator_service.py
  guardrails/validator_service/validator_service_base.py
  tests/integration_tests/test_async_streaming.py

guardrails/classes/validation/validation_summary.py

Lines changed: 5 additions & 4 deletions
@@ -12,10 +12,11 @@ class ValidationSummary(IValidationSummary, ArbitraryModel):
     def _generate_summaries_from_validator_logs(
         validator_logs: List[ValidatorLogs],
     ) -> Iterator["ValidationSummary"]:
-        """
-        Generate a list of ValidationSummary objects from a list of
-        ValidatorLogs objects. Using an iterator to allow serializing
-        the summaries to other formats.
+        """Generate a list of ValidationSummary objects from a list of
+        ValidatorLogs objects.
+
+        Using an iterator to allow serializing the summaries to other
+        formats.
         """
         for log in validator_logs:
             validation_result = log.validation_result

guardrails/run/async_stream_runner.py

Lines changed: 100 additions & 32 deletions
@@ -8,12 +8,11 @@
     cast,
 )
 
-
+from guardrails.validator_service import AsyncValidatorService
 from guardrails.actions.reask import SkeletonReAsk
 from guardrails.classes import ValidationOutcome
 from guardrails.classes.history import Call, Inputs, Iteration, Outputs
 from guardrails.classes.output_type import OutputTypes
-from guardrails.constants import pass_status
 from guardrails.llm_providers import (
     AsyncLiteLLMCallable,
     AsyncPromptCallableBase,
@@ -28,6 +27,11 @@
 from guardrails.run.async_runner import AsyncRunner
 from guardrails.telemetry import trace_async_stream_step
 from guardrails.hub_telemetry.hub_tracing import async_trace_stream
+from guardrails.types import OnFailAction
+from guardrails.classes.validation.validation_result import (
+    PassResult,
+    FailResult,
+)
 
 
 class AsyncStreamRunner(AsyncRunner, StreamRunner):
@@ -133,49 +137,113 @@ async def async_step(
         parsed_fragment, validated_fragment, valid_op = None, None, None
         verified = set()
         validation_response = ""
+        validation_progress = {}
+        refrain_triggered = False
+        validation_passed = True
 
         if self.output_type == OutputTypes.STRING:
+            validator_service = AsyncValidatorService(self.disable_tracer)
             async for chunk in stream_output:
                 chunk_text = self.get_chunk_text(chunk, api)
                 _ = self.is_last_chunk(chunk, api)
-                fragment += chunk_text
 
-                parsed_chunk, move_to_next = self.parse(
-                    chunk_text, output_schema, verified=verified
-                )
-                if move_to_next:
-                    continue
-                validated_fragment = await self.async_validate(
+                fragment += chunk_text
+                results = await validator_service.async_partial_validate(
+                    chunk_text,
+                    self.metadata,
+                    self.validation_map,
                     iteration,
-                    index,
-                    parsed_chunk,
-                    output_schema,
-                    validate_subschema=True,
-                    stream=True,
+                    "$",
+                    "$",
+                    True,
                 )
-                # TODO why? how does it happen in the other places we handle streams
-                if validated_fragment is None:
-                    validated_fragment = ""
-
-                if isinstance(validated_fragment, SkeletonReAsk):
-                    raise ValueError(
-                        "Received fragment schema is an invalid sub-schema "
-                        "of the expected output JSON schema."
+                validators = self.validation_map["$"] or []
+                # collect the result validated_chunk into validation progress
+                # per validator
+                for result in results:
+                    validator_log = result.validator_logs  # type: ignore
+                    validator = next(
+                        filter(
+                            lambda x: x.rail_alias == validator_log.registered_name,
+                            validators,
+                        ),
+                        None,
                     )
+                    if (
+                        validator_log.validation_result
+                        and validator_log.validation_result.validated_chunk
+                    ):
+                        is_filter = validator.on_fail_descriptor is OnFailAction.FILTER  # type: ignore
+                        is_refrain = (
+                            validator.on_fail_descriptor is OnFailAction.REFRAIN  # type: ignore
+                        )
+                        if validator_log.validation_result.outcome == "fail":
+                            validation_passed = False
+                            reasks, valid_op = self.introspect(
+                                validator_log.validation_result
+                            )
+                            if reasks:
+                                raise ValueError(
+                                    "Reasks are not yet supported with streaming. Please "
+                                    "remove reasks from schema or disable streaming."
+                                )
 
-                reasks, valid_op = self.introspect(validated_fragment)
-                if reasks:
-                    raise ValueError(
-                        "Reasks are not yet supported with streaming. Please "
-                        "remove reasks from schema or disable streaming."
+                        if isinstance(validator_log.validation_result, PassResult):
+                            chunk = validator_log.validation_result.validated_chunk
+                        elif isinstance(validator_log.validation_result, FailResult):
+                            if is_filter or is_refrain:
+                                refrain_triggered = True
+                                chunk = ""
+                            else:
+                                chunk = validator_service.perform_correction(
+                                    validator_log.validation_result,
+                                    validator_log.validation_result.validated_chunk,
+                                    validator,  # type: ignore
+                                    rechecked_value=None,
+                                )  # type: ignore
+
+                        if not hasattr(
+                            validation_progress, validator_log.validator_name
+                        ):
+                            validation_progress[validator_log.validator_name] = ""
+
+                        validation_progress[validator_log.validator_name] += chunk
+                # if there is an entry for every validator
+                # run a merge and emit a validation outcome
+                if len(validation_progress) == len(validators):
+                    if refrain_triggered:
+                        current = ""
+                    else:
+                        merge_chunks = []
+                        for piece in validation_progress:
+                            merge_chunks.append(validation_progress[piece])
+
+                        current = validator_service.multi_merge(fragment, merge_chunks)
+
+                    vo = ValidationOutcome(
+                        call_id=call_log.id,  # type: ignore
+                        raw_llm_output=fragment,
+                        validated_output=current,
+                        validation_passed=True,
                     )
-                validation_response += validated_fragment
-                passed = call_log.status == pass_status
+                    fragment = ""
+                    validation_progress = {}
+                    refrain_triggered = False
+
+                    yield vo
+
+            # if theres anything left merge and emit a chunk
+            if len(validation_progress) > 0:
+                merge_chunks = []
+                for piece in validation_progress:
+                    merge_chunks.append(validation_progress[piece])
+
+                current = validator_service.multi_merge(fragment, merge_chunks)
                 yield ValidationOutcome(
                     call_id=call_log.id,  # type: ignore
-                    raw_llm_output=chunk_text,
-                    validated_output=validated_fragment,
-                    validation_passed=passed,
+                    raw_llm_output=fragment,
+                    validated_output=current,
+                    validation_passed=validation_passed,
                 )
         else:
             async for chunk in stream_output:
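
The key behavioral change above: instead of validating and yielding each raw chunk directly, the runner now accumulates every validator's validated chunk in validation_progress and only merges and yields a ValidationOutcome once each validator registered at "$" has reported for the current fragment. A simplified, standalone model of that gating follows; the validator names and chunks are illustrative, not the runner's actual state.

# Toy model of the emit-when-every-validator-has-reported gate used above;
# validator names and chunk values are made up for illustration.
from typing import Dict, List


def emit_ready(validation_progress: Dict[str, str], validators: List[str]) -> bool:
    # The runner merges and yields only once every registered validator has
    # contributed a validated chunk for the current fragment.
    return len(validation_progress) == len(validators)


validators = ["mock_detect_pii", "lower_case"]
progress: Dict[str, str] = {}

progress["mock_detect_pii"] = "<PERSON>, under "
assert not emit_ready(progress, validators)  # still waiting on lower_case

progress["lower_case"] = "john, under "
assert emit_ready(progress, validators)  # both reported: merge and yield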

guardrails/validator_service/async_validator_service.py

Lines changed: 32 additions & 0 deletions
@@ -253,6 +253,38 @@ async def validate_child(
 
         return value, metadata
 
+    async def async_partial_validate(
+        self,
+        value: Any,
+        metadata: dict,
+        validator_map: ValidatorMap,
+        iteration: Iteration,
+        absolute_path: str,
+        reference_path: str,
+        stream: Optional[bool] = False,
+        **kwargs,
+    ) -> list[ValidatorRun]:
+        # Then validate the parent value
+        validators = validator_map.get(reference_path, [])
+        coroutines: List[Coroutine[Any, Any, ValidatorRun]] = []
+
+        for validator in validators:
+            coroutines.append(
+                self.run_validator(
+                    iteration,
+                    validator,
+                    value,
+                    metadata,
+                    absolute_path,
+                    stream=stream,
+                    **kwargs,
+                )
+            )
+
+        results = await asyncio.gather(*coroutines)
+
+        return results
+
     async def async_validate(
         self,
         value: Any,
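
The new async_partial_validate helper fans out every validator registered at reference_path and awaits them concurrently with asyncio.gather. A standalone sketch of the same fan-out pattern follows; the toy validator coroutines are hypothetical stand-ins, not guardrails APIs.

# Standalone asyncio.gather fan-out, mirroring async_partial_validate above.
# The validator callables here are hypothetical, not guardrails validators.
import asyncio
from typing import Awaitable, Callable, List


async def fan_out(value: str, validators: List[Callable[[str], Awaitable[str]]]) -> List[str]:
    coroutines = [validator(value) for validator in validators]
    # Every validator runs concurrently against the same chunk.
    return await asyncio.gather(*coroutines)


async def lower(value: str) -> str:
    return value.lower()


async def strip_exclaims(value: str) -> str:
    return value.replace("!", "")


print(asyncio.run(fan_out("HELLO world!", [lower, strip_exclaims])))
# ['hello world!', 'HELLO world']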

guardrails/validator_service/sequential_validator_service.py

Lines changed: 0 additions & 9 deletions
@@ -10,7 +10,6 @@
     StreamValidationResult,
     ValidationResult,
 )
-from guardrails.merge import merge
 from guardrails.types import ValidatorMap, OnFailAction
 from guardrails.utils.exception_utils import UserFacingException
 from guardrails.classes.validation.validator_logs import ValidatorLogs
@@ -108,14 +107,6 @@ def run_validators_stream(
             **kwargs,
         )
 
-    # requires at least 2 validators
-    def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
-        current = new_values.pop()
-        while len(new_values) > 0:
-            nextval = new_values.pop()
-            current = merge(current, nextval, original)
-        return current
-
     def run_validators_stream_fix(
         self,
         iteration: Iteration,

guardrails/validator_service/validator_service_base.py

Lines changed: 8 additions & 0 deletions
@@ -167,6 +167,14 @@ def run_validator(
     ) -> ValidatorRun:
         raise NotImplementedError
 
+    # requires at least 2 validators
+    def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
+        current = new_values.pop()
+        while len(new_values) > 0:
+            nextval = new_values.pop()
+            current = merge(current, nextval, original)
+        return current
+
     def merge_results(self, original_value: Any, new_values: list[Any]) -> Any:
         new_vals = deepcopy(new_values)
         current = new_values.pop()
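
multi_merge moves unchanged from the sequential service into the shared base class so the async stream runner can use it too: it folds each validator's fixed chunk into one string by repeatedly three-way merging against the original fragment. A usage sketch based on the POETRY_CHUNKS test below; the merged string shown is the intended result, not a captured run.

# Sketch of what multi_merge does with two FIX validators (a PII replacer and
# a lowercaser); the expected output is the intent, not a captured run.
from guardrails.merge import merge

original = "John, under GOLDEN bridges"
per_validator_chunks = [
    "<PERSON>, under GOLDEN bridges",  # PII fix applied
    "john, under golden bridges",  # lowercase fix applied
]

current = per_validator_chunks.pop()
while len(per_validator_chunks) > 0:
    nextval = per_validator_chunks.pop()
    current = merge(current, nextval, original)

print(current)  # expected: "<PERSON>, under golden bridges"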

tests/integration_tests/test_async_streaming.py

Lines changed: 70 additions & 23 deletions
@@ -21,6 +21,18 @@
 )
 from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII
 
+POETRY_CHUNKS = [
+    "John, under ",
+    "GOLDEN bridges",
+    ", roams,\n",
+    "SAN Francisco's ",
+    "hills, his HOME.\n",
+    "Dreams of",
+    " FOG, and salty AIR,\n",
+    "In his HEART",
+    ", he's always THERE.",
+]
+
 
 @register_validator(name="minsentencelength", data_type=["string", "list"])
 class MinSentenceLengthValidator(Validator):
@@ -131,21 +143,54 @@ async def gen():
         self.completion_stream = gen()
 
 
-POETRY_CHUNKS = [
-    "John, under ",
-    "GOLDEN bridges",
-    ", roams,\n",
-    "SAN Francisco's ",
-    "hills, his HOME.\n",
-    "Dreams of",
-    " FOG, and salty AIR,\n",
-    "In his HEART",
-    ", he's always THERE.",
-]
+@pytest.mark.asyncio
+async def test_async_streaming_fix_behavior_two_validators(mocker):
+    mocker.patch(
+        "litellm.acompletion",
+        return_value=Response(POETRY_CHUNKS),
+    )
+
+    guard = gd.AsyncGuard().use_many(
+        MockDetectPII(
+            on_fail=OnFailAction.FIX,
+            pii_entities="pii",
+            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
+        ),
+        LowerCase(on_fail=OnFailAction.FIX),
+    )
+    prompt = """Write me a 4 line poem about John in San Francisco.
+    Make every third word all caps."""
+    gen = await guard(
+        model="gpt-3.5-turbo",
+        max_tokens=10,
+        temperature=0,
+        stream=True,
+        prompt=prompt,
+    )
+    text = ""
+    original = ""
+    async for res in gen:
+        original = original + res.raw_llm_output
+        text = text + res.validated_output
+
+    assert (
+        text
+        == """<PERSON>, under golden bridges, roams,
+<LOCATION> hills, his home.
+dreams of fog, and salty air,
+in his heart, he's always there."""
+    )
+    assert (
+        original
+        == """John, under GOLDEN bridges, roams,
+SAN Francisco's hills, his HOME.
+Dreams of FOG, and salty AIR,
+In his HEART, he's always THERE."""
+    )
 
 
 @pytest.mark.asyncio
-async def test_filter_behavior(mocker):
+async def test_async_streaming_filter_behavior(mocker):
     mocker.patch(
         "litellm.acompletion",
         return_value=Response(POETRY_CHUNKS),
@@ -169,16 +214,18 @@ async def test_filter_behavior(mocker):
         prompt=prompt,
     )
 
-    text = ""
-    final_res = None
+    validated = ""
+    raw_llm_output = ""
+
     async for res in gen:
-        final_res = res
-        text += res.validated_output
-
-    assert final_res.raw_llm_output == ", he's always THERE."
-    # TODO deep dive this
-    assert text == (
-        "John, under GOLDEN bridges, roams,\n"
-        "SAN Francisco's Dreams of FOG, and salty AIR,\n"
-        "In his HEART"
+        validated += res.validated_output
+        raw_llm_output += res.raw_llm_output
+
+    assert validated == ""
+    assert (
+        raw_llm_output
+        == """John, under GOLDEN bridges, roams,
+SAN Francisco's hills, his HOME.
+Dreams of FOG, and salty AIR,
+In his HEART, he's always THERE."""
     )
