
Commit 618ad21

Merge pull request #1100 from guardrails-ai/async_streaming_telem_updates_and_fixes

Some updates and fixes for async streaming and telemetry

2 parents b0e1da1 + 70d4f8c

File tree: 10 files changed, +253 −15 lines

guardrails/async_guard.py

Lines changed: 10 additions & 2 deletions
@@ -369,7 +369,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Here we have an async generator
@@ -391,7 +395,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Why are we using a different method here instead of just overriding?
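The effect of this change: with the old expression, an unset setting (None) became `not None`, i.e. True, which force-disabled the tracer even though the user never opted out; now None is passed through so the callee can apply its own default. A minimal standalone sketch of the resulting tri-state (the helper name `resolve_disable_tracer` is illustrative, not part of the codebase):

from typing import Optional

def resolve_disable_tracer(allow_metrics_collection: Optional[bool]) -> Optional[bool]:
    # An explicit bool is a deliberate user choice and wins.
    if isinstance(allow_metrics_collection, bool):
        return not allow_metrics_collection
    # Unset (None): defer to the downstream default instead of disabling.
    return None

assert resolve_disable_tracer(True) is False   # metrics on  -> tracer enabled
assert resolve_disable_tracer(False) is True   # metrics off -> tracer disabled
assert resolve_disable_tracer(None) is None    # unset -> let the runner decide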

guardrails/cli/hub/install.py

Lines changed: 18 additions & 2 deletions
@@ -10,9 +10,14 @@
 from guardrails.cli.version import version_warnings_if_applicable


-@hub_command.command()
+# Quick note: This is the command for `guardrails hub install`. We change the name of
+# the function def to prevent confusion, lest people import it directly and call it
+# with a string for package_uris instead of a list, which behaves oddly. If you need to
+# call install from a script, please consider importing install from guardrails,
+# not guardrails.cli.hub.install.
+@hub_command.command(name="install")
 @trace(name="guardrails-cli/hub/install")
-def install(
+def install_cli(
     package_uris: List[str] = typer.Argument(
         ...,
         help="URIs to the packages to install. Example: hub://guardrails/regex_match hub://guardrails/toxic_language",
@@ -33,6 +38,17 @@ def install(
     ),
 ):
     try:
+        if isinstance(package_uris, str):
+            logger.error(
+                f"`install` in {__file__} was called with a string instead of "
+                "a list! This can happen if it is invoked directly instead of "
+                "being run via the CLI. Did you mean to import `from guardrails import "
+                "install` instead? Recovering..."
+            )
+            package_uris = [
+                package_uris,
+            ]
+
         from guardrails.hub.install import install_multiple

         def confirm():
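The recovery branch matters because iterating a bare string yields single characters, so a direct call with a string would try to install one "package" per letter. A hypothetical mini-repro (the function name `install_all` is invented for illustration):

def install_all(package_uris):
    if isinstance(package_uris, str):
        # Same recovery as the diff above: wrap the stray string in a list.
        package_uris = [package_uris]
    for uri in package_uris:
        print("installing", uri)

install_all("hub://guardrails/regex_match")    # one install, not one per character
install_all(["hub://guardrails/regex_match"])  # what Typer passes from the CLI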

guardrails/guard.py

Lines changed: 10 additions & 2 deletions
@@ -865,7 +865,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         return runner(call_log=call_log, prompt_params=prompt_params)
@@ -884,7 +888,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         call = runner(call_log=call_log, prompt_params=prompt_params)

guardrails/hub/install.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def install(
     Examples:
         >>> RegexMatch = install("hub://guardrails/regex_match").RegexMatch

-        >>> install("hub://guardrails/regex_match);
+        >>> install("hub://guardrails/regex_match")
         >>> import guardrails.hub.regex_match as regex_match
     """

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 4 additions & 3 deletions
@@ -3,6 +3,7 @@
     Any,
     Dict,
     Optional,
+    AsyncGenerator,
 )

 from opentelemetry.trace import Span
@@ -224,7 +225,7 @@ def wrapper(*args, **kwargs):
     return decorator


-async def _run_async_gen(fn, *args, **kwargs):
+async def _run_async_gen(fn, *args, **kwargs) -> AsyncGenerator[Any, None]:
     gen = fn(*args, **kwargs)
     async for item in gen:
         yield item
@@ -238,7 +239,7 @@ def async_trace_stream(
 ):
     def decorator(fn):
         @wraps(fn)
-        async def wrapper(*args, **kwargs):
+        def wrapper(*args, **kwargs):
             hub_telemetry = HubTelemetry()
             if hub_telemetry._enabled and hub_telemetry._tracer is not None:
                 with hub_telemetry._tracer.start_span(
@@ -252,7 +253,7 @@ async def wrapper(*args, **kwargs):
                     nonlocal origin
                     origin = origin if origin is not None else name
                     add_attributes(span, attrs, name, origin, *args, **kwargs)
-                    return _run_async_gen(fn, *args, **kwargs)
+                    return fn(*args, **kwargs)
             else:
                 return fn(*args, **kwargs)
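Why the wrapper changed from `async def` to a plain `def`: a function that returns an async generator must itself be callable without `await` for `async for item in fn(...)` to keep working; an `async def` wrapper adds an extra awaitable layer and changes the call shape for every caller. A self-contained sketch of the two shapes (illustrative, not the library code):

import asyncio

async def numbers():
    for i in range(3):
        yield i

def sync_wrapper():         # after the fix: returns the generator directly
    return numbers()

async def async_wrapper():  # before the fix: adds an extra awaitable layer
    return numbers()

async def main():
    async for n in sync_wrapper():        # same call shape as the unwrapped fn
        print(n)
    async for n in await async_wrapper(): # caller must add an extra await
        print(n)

asyncio.run(main())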

guardrails/run/async_stream_runner.py

Lines changed: 5 additions & 1 deletion
@@ -153,6 +153,10 @@ async def async_step(
                 validate_subschema=True,
                 stream=True,
             )
+            # TODO why? how does it happen in the other places we handle streams
+            if validated_fragment is None:
+                validated_fragment = ""
+
             if isinstance(validated_fragment, SkeletonReAsk):
                 raise ValueError(
                     "Received fragment schema is an invalid sub-schema "
@@ -165,7 +169,7 @@ async def async_step(
                     "Reasks are not yet supported with streaming. Please "
                     "remove reasks from schema or disable streaming."
                 )
-            validation_response += cast(str, validated_fragment)
+            validation_response += validated_fragment
             passed = call_log.status == pass_status
             yield ValidationOutcome(
                 call_id=call_log.id,  # type: ignore
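An assumption about the failure mode this guards against: when validators filter a fragment away entirely, the validated result can come back as None, and concatenating None onto a str raises TypeError. A tiny sketch:

validation_response = "so far "
validated_fragment = None  # e.g. every validator filtered the fragment
if validated_fragment is None:
    validated_fragment = ""
validation_response += validated_fragment  # safe; without the guard: TypeError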

guardrails/telemetry/guard_tracing.py

Lines changed: 17 additions & 3 deletions
@@ -11,7 +11,7 @@
 )

 from opentelemetry import context, trace
-from opentelemetry.trace import StatusCode, Tracer, Span
+from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer

 from guardrails.settings import settings
 from guardrails.classes.generic.stack import Stack
@@ -22,6 +22,10 @@
 from guardrails.telemetry.common import add_user_attributes
 from guardrails.version import GUARDRAILS_VERSION

+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext

 # from sentence_transformers import SentenceTransformer
 # import numpy as np
@@ -195,8 +199,18 @@ async def trace_async_stream_guard(
     while next_exists:
         try:
             res = await anext(result)  # type: ignore
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "new_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
             yield res
         except StopIteration:
             next_exists = False
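The pattern above handles a long-lived async stream outliving its guard span: once the original span has ended (is no longer recording), a fresh span is started and tied back with an OpenTelemetry Link so the trace stays stitched together. A standalone sketch of the same pattern (span names are illustrative):

from opentelemetry import trace
from opentelemetry.trace import Link

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("guard") as guard_span:
    pass  # the span ends when this block exits

if not guard_span.is_recording():
    # Links must be supplied at span creation time.
    with tracer.start_as_current_span(
        "new_guard_span",
        links=[Link(guard_span.get_span_context())],
    ) as new_span:
        new_span.set_attribute("linked_from_ended_span", True)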

guardrails/telemetry/runner_tracing.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
 from guardrails.utils.safe_get import safe_get
 from guardrails.version import GUARDRAILS_VERSION

+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext

 #########################################
 ### START Runner.step Instrumentation ###
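The guarded import exists because anext() only became a builtin in Python 3.10. A minimal sketch of what such a polyfill can look like (the project's actual implementation lives in guardrails.utils.polyfills and may differ):

async def anext(async_iterator):
    """Return the next item from an async iterator."""
    return await async_iterator.__anext__()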

guardrails/utils/hub_telemetry_utils.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def initialize_tracer(
     """Initializes a tracer for Guardrails Hub."""
     if enabled is None:
         enabled = settings.rc.enable_metrics or False
-
     self._enabled = enabled
     self._carrier = {}
     self._service_name = service_name
Lines changed: 184 additions & 0 deletions (new file)
@@ -0,0 +1,184 @@
# 3 tests
# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
# 3. Test string schema streaming
# Using the LowerCase Validator, and a custom validator to show new streaming behavior
from typing import Any, Callable, Dict, List, Optional, Union

import asyncio
import pytest

import guardrails as gd
from guardrails.utils.casting_utils import to_int
from guardrails.validator_base import (
    ErrorSpan,
    FailResult,
    OnFailAction,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)
from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII


@register_validator(name="minsentencelength", data_type=["string", "list"])
class MinSentenceLengthValidator(Validator):
    def __init__(
        self,
        min: Optional[int] = None,
        max: Optional[int] = None,
        on_fail: Optional[Callable] = None,
    ):
        super().__init__(
            on_fail=on_fail,
            min=min,
            max=max,
        )
        self._min = to_int(min)
        self._max = to_int(max)

    def sentence_split(self, value):
        return list(map(lambda x: x + ".", value.split(".")[:-1]))

    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
        sentences = self.sentence_split(value)
        error_spans = []
        index = 0
        for sentence in sentences:
            if len(sentence) < self._min:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length less than {self._min}. "
                        f"Please return a longer output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            if len(sentence) > self._max:
                error_spans.append(
                    ErrorSpan(
                        start=index,
                        end=index + len(sentence),
                        reason=f"Sentence has length greater than {self._max}. "
                        f"Please return a shorter output, "
                        f"that is shorter than {self._max} characters.",
                    )
                )
            index = index + len(sentence)
        if len(error_spans) > 0:
            return FailResult(
                validated_chunk=value,
                error_spans=error_spans,
                error_message=f"Sentence has length less than {self._min}. "
                f"Please return a longer output, "
                f"that is shorter than {self._max} characters.",
            )
        return PassResult(validated_chunk=value)

    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
        return super().validate_stream(chunk, metadata, **kwargs)


class Delta:
    content: str

    def __init__(self, content):
        self.content = content


class Choice:
    text: str
    finish_reason: str
    index: int
    delta: Delta

    def __init__(self, text, delta, finish_reason, index=0):
        self.index = index
        self.delta = delta
        self.text = text
        self.finish_reason = finish_reason


class MockOpenAIV1ChunkResponse:
    choices: list
    model: str

    def __init__(self, choices, model):
        self.choices = choices
        self.model = model


class Response:
    def __init__(self, chunks):
        self.chunks = chunks

        async def gen():
            for chunk in self.chunks:
                yield MockOpenAIV1ChunkResponse(
                    choices=[
                        Choice(
                            delta=Delta(content=chunk),
                            text=chunk,
                            finish_reason=None,
                        )
                    ],
                    model="OpenAI model name",
                )
                await asyncio.sleep(0)  # Yield control to the event loop

        self.completion_stream = gen()


POETRY_CHUNKS = [
    "John, under ",
    "GOLDEN bridges",
    ", roams,\n",
    "SAN Francisco's ",
    "hills, his HOME.\n",
    "Dreams of",
    " FOG, and salty AIR,\n",
    "In his HEART",
    ", he's always THERE.",
]


@pytest.mark.asyncio
async def test_filter_behavior(mocker):
    mocker.patch(
        "litellm.acompletion",
        return_value=Response(POETRY_CHUNKS),
    )

    guard = gd.AsyncGuard().use_many(
        MockDetectPII(
            on_fail=OnFailAction.FIX,
            pii_entities="pii",
            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
        ),
        LowerCase(on_fail=OnFailAction.FILTER),
    )
    prompt = """Write me a 4 line poem about John in San Francisco.
Make every third word all caps."""
    gen = await guard(
        model="gpt-3.5-turbo",
        max_tokens=10,
        temperature=0,
        stream=True,
        prompt=prompt,
    )

    text = ""
    final_res = None
    async for res in gen:
        final_res = res
        text += res.validated_output

    assert final_res.raw_llm_output == ", he's always THERE."
    # TODO deep dive this
    assert text == (
        "John, under GOLDEN bridges, roams,\n"
        "SAN Francisco's Dreams of FOG, and salty AIR,\n"
        "In his HEART"
    )
