
Commit e26316f

wip some updates and fixes for async streaming and telem
1 parent 0ff4f6f commit e26316f

9 files changed: +278 additions, −11 deletions

guardrails/async_guard.py

Lines changed: 10 additions & 2 deletions
@@ -369,7 +369,11 @@ async def _exec(
                 output=llm_output,
                 base_model=self._base_model,
                 full_schema_reask=full_schema_reask,
-                disable_tracer=(not self._allow_metrics_collection),
+                disable_tracer=(
+                    not self._allow_metrics_collection
+                    if isinstance(self._allow_metrics_collection, bool)
+                    else None
+                ),
                 exec_options=self._exec_opts,
             )
             # Here we have an async generator
@@ -391,7 +395,11 @@ async def _exec(
                 output=llm_output,
                 base_model=self._base_model,
                 full_schema_reask=full_schema_reask,
-                disable_tracer=(not self._allow_metrics_collection),
+                disable_tracer=(
+                    not self._allow_metrics_collection
+                    if isinstance(self._allow_metrics_collection, bool)
+                    else None
+                ),
                 exec_options=self._exec_opts,
             )
             # Why are we using a different method here instead of just overriding?
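The same disable_tracer ternary is applied again in guardrails/guard.py below. The intent: _allow_metrics_collection is effectively tri-state (True, False, or unset), and only an explicit boolean should force the tracer on or off. A minimal sketch of the resolved semantics, using a hypothetical helper name for illustration:

# Sketch only: resolve_disable_tracer is a hypothetical helper mirroring the
# ternary in the diff above. An explicit bool opts the tracer in or out;
# anything else (e.g. None, meaning "not configured") yields None so that
# downstream defaults can decide.
from typing import Any, Optional

def resolve_disable_tracer(allow_metrics_collection: Any) -> Optional[bool]:
    if isinstance(allow_metrics_collection, bool):
        return not allow_metrics_collection
    return None

assert resolve_disable_tracer(True) is False  # metrics allowed -> tracing on
assert resolve_disable_tracer(False) is True  # metrics refused -> tracing off
assert resolve_disable_tracer(None) is None   # unset -> defer to defaults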

guardrails/classes/rc.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def load(cls, logger: Optional[logging.Logger] = None) -> "RC":
                     value = to_bool(value)

                 config[key] = value
-
+            print("===== loaded config", config)
             rc_file.close()

         # backfill no_metrics, handle defaults

guardrails/guard.py

Lines changed: 10 additions & 2 deletions
@@ -865,7 +865,11 @@ def _exec(
                 output=llm_output,
                 base_model=self._base_model,
                 full_schema_reask=full_schema_reask,
-                disable_tracer=(not self._allow_metrics_collection),
+                disable_tracer=(
+                    not self._allow_metrics_collection
+                    if isinstance(self._allow_metrics_collection, bool)
+                    else None
+                ),
                 exec_options=self._exec_opts,
             )
             return runner(call_log=call_log, prompt_params=prompt_params)
@@ -884,7 +888,11 @@ def _exec(
                 output=llm_output,
                 base_model=self._base_model,
                 full_schema_reask=full_schema_reask,
-                disable_tracer=(not self._allow_metrics_collection),
+                disable_tracer=(
+                    not self._allow_metrics_collection
+                    if isinstance(self._allow_metrics_collection, bool)
+                    else None
+                ),
                 exec_options=self._exec_opts,
             )
             call = runner(call_log=call_log, prompt_params=prompt_params)

guardrails/hub_telemetry/hub_tracing.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def wrapper(*args, **kwargs):
             nonlocal origin
             origin = origin if origin is not None else name
             add_attributes(span, attrs, name, origin, *args, **kwargs)
-            return _run_async_gen(fn, *args, **kwargs)
+            return fn(*args, **kwargs)
         else:
             return fn(*args, **kwargs)
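The wrapper above no longer routes async-generator functions through _run_async_gen; it returns fn(*args, **kwargs) directly. A standalone sketch (illustrative names only) of why that is safe: calling an async-generator function does not execute its body, it synchronously returns an async-generator object that the caller then iterates.

import asyncio

async def stream_chunks():
    # The body only runs while the caller iterates, not at call time.
    for chunk in ("a", "b", "c"):
        yield chunk

def wrapper():
    return stream_chunks()  # no await or helper needed

async def main():
    async for chunk in wrapper():
        print(chunk)

asyncio.run(main())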

guardrails/run/async_stream_runner.py

Lines changed: 5 additions & 1 deletion
@@ -153,6 +153,10 @@ async def async_step(
                 validate_subschema=True,
                 stream=True,
             )
+            # TODO why? how does it happen in the other places we handle streams
+            if validated_fragment is None:
+                validated_fragment = ""
+
             if isinstance(validated_fragment, SkeletonReAsk):
                 raise ValueError(
                     "Received fragment schema is an invalid sub-schema "
@@ -165,7 +169,7 @@ async def async_step(
                     "Reasks are not yet supported with streaming. Please "
                     "remove reasks from schema or disable streaming."
                 )
-            validation_response += cast(str, validated_fragment)
+            validation_response += validated_fragment
             passed = call_log.status == pass_status
             yield ValidationOutcome(
                 call_id=call_log.id,  # type: ignore
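Context for the None check added in the first hunk: a streaming validator can produce no text for a fragment (for example when a FILTER action drops a chunk, which the new test at the bottom of this commit exercises), and concatenating None onto a str raises TypeError. A minimal sketch of the failure mode the coalescing guards against:

validation_response = "partial output so far"
validated_fragment = None  # e.g. a filtered-out chunk produced no text

# Without the guard this raises:
#   TypeError: can only concatenate str (not "NoneType") to str
# Coalescing to "" keeps the accumulation going.
if validated_fragment is None:
    validated_fragment = ""
validation_response += validated_fragment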

guardrails/telemetry/guard_tracing.py

Lines changed: 16 additions & 3 deletions
@@ -11,7 +11,7 @@
 )

 from opentelemetry import context, trace
-from opentelemetry.trace import StatusCode, Tracer, Span
+from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer

 from guardrails.settings import settings
 from guardrails.classes.generic.stack import Stack
@@ -22,6 +22,10 @@
 from guardrails.telemetry.common import add_user_attributes
 from guardrails.version import GUARDRAILS_VERSION

+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext

 # from sentence_transformers import SentenceTransformer
 # import numpy as np
@@ -195,8 +199,17 @@ async def trace_async_stream_guard(
     while next_exists:
         try:
             res = await anext(result)  # type: ignore
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "new_guard_span", links=[Link(guard_span.get_span_context())]
+                ) as new_span:
+                    guard_span = new_span
+
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
             yield res
         except StopIteration:
             next_exists = False
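The third hunk recovers when the guard span has already ended mid-stream: it starts a fresh span carrying a Link back to the old span's SpanContext so trace backends can stitch the pieces together. A minimal standalone sketch of this span-link pattern (the helper and span name are illustrative, not part of the codebase):

from opentelemetry import trace
from opentelemetry.trace import Link

tracer = trace.get_tracer(__name__)

def start_linked_span(previous_span, name: str = "continued_guard_span"):
    # Links are supplied at span creation time; the new span records the old
    # span's context so the two show up as related in the trace backend.
    return tracer.start_span(name, links=[Link(previous_span.get_span_context())])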

guardrails/telemetry/runner_tracing.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
 from guardrails.utils.safe_get import safe_get
 from guardrails.version import GUARDRAILS_VERSION

+import sys
+
+if sys.version_info.minor < 10:
+    from guardrails.utils.polyfills import anext

 #########################################
 ### START Runner.step Instrumentation ###
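Both telemetry modules now pull in an anext polyfill on interpreters older than Python 3.10, where the built-in does not exist. The actual guardrails.utils.polyfills implementation is not shown in this diff and may differ; a sketch of what such a fallback could look like:

# Hypothetical sketch of an anext() fallback for Python 3.9 and older: the
# built-in arrived in 3.10, and awaiting __anext__() directly is equivalent
# for the simple one-argument case.
async def anext(async_iterator):
    return await async_iterator.__anext__()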

guardrails/utils/hub_telemetry_utils.py

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ def initialize_tracer(
         """Initializes a tracer for Guardrails Hub."""
         if enabled is None:
             enabled = settings.rc.enable_metrics or False
-
         self._enabled = enabled
         self._carrier = {}
         self._service_name = service_name
Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+# 3 tests
+# 1. Test streaming with OpenAICallable (mock openai.Completion.create)
+# 2. Test streaming with OpenAIChatCallable (mock openai.ChatCompletion.create)
+# 3. Test string schema streaming
+# Using the LowerCase Validator, and a custom validator to show new streaming behavior
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import asyncio
+import pytest
+from pydantic import BaseModel, Field
+
+import guardrails as gd
+from guardrails.utils.casting_utils import to_int
+from guardrails.validator_base import (
+    ErrorSpan,
+    FailResult,
+    OnFailAction,
+    PassResult,
+    ValidationResult,
+    Validator,
+    register_validator,
+)
+from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII
+
+expected_raw_output = {"statement": "I am DOING well, and I HOPE you aRe too."}
+expected_fix_output = {"statement": "i am doing well, and i hope you are too."}
+expected_noop_output = {"statement": "I am DOING well, and I HOPE you aRe too."}
+expected_filter_refrain_output = {}
+
+
+@register_validator(name="minsentencelength", data_type=["string", "list"])
+class MinSentenceLengthValidator(Validator):
+    def __init__(
+        self,
+        min: Optional[int] = None,
+        max: Optional[int] = None,
+        on_fail: Optional[Callable] = None,
+    ):
+        super().__init__(
+            on_fail=on_fail,
+            min=min,
+            max=max,
+        )
+        self._min = to_int(min)
+        self._max = to_int(max)
+
+    def sentence_split(self, value):
+        return list(map(lambda x: x + ".", value.split(".")[:-1]))
+
+    def validate(self, value: Union[str, List], metadata: Dict) -> ValidationResult:
+        sentences = self.sentence_split(value)
+        error_spans = []
+        index = 0
+        for sentence in sentences:
+            if len(sentence) < self._min:
+                error_spans.append(
+                    ErrorSpan(
+                        start=index,
+                        end=index + len(sentence),
+                        reason=f"Sentence has length less than {self._min}. "
+                        f"Please return a longer output, "
+                        f"that is shorter than {self._max} characters.",
+                    )
+                )
+            if len(sentence) > self._max:
+                error_spans.append(
+                    ErrorSpan(
+                        start=index,
+                        end=index + len(sentence),
+                        reason=f"Sentence has length greater than {self._max}. "
+                        f"Please return a shorter output, "
+                        f"that is shorter than {self._max} characters.",
+                    )
+                )
+            index = index + len(sentence)
+        if len(error_spans) > 0:
+            return FailResult(
+                validated_chunk=value,
+                error_spans=error_spans,
+                error_message=f"Sentence has length less than {self._min}. "
+                f"Please return a longer output, "
+                f"that is shorter than {self._max} characters.",
+            )
+        return PassResult(validated_chunk=value)
+
+    def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult:
+        return super().validate_stream(chunk, metadata, **kwargs)
+
+
+class Delta:
+    content: str
+
+    def __init__(self, content):
+        self.content = content
+
+
+class Choice:
+    text: str
+    finish_reason: str
+    index: int
+    delta: Delta
+
+    def __init__(self, text, delta, finish_reason, index=0):
+        self.index = index
+        self.delta = delta
+        self.text = text
+        self.finish_reason = finish_reason
+
+
+class MockOpenAIV1ChunkResponse:
+    choices: list
+    model: str
+
+    def __init__(self, choices, model):
+        self.choices = choices
+        self.model = model
+
+
+class Response:
+    def __init__(self, chunks):
+        self.chunks = chunks
+
+        async def gen():
+            for chunk in self.chunks:
+                yield MockOpenAIV1ChunkResponse(
+                    choices=[
+                        Choice(
+                            delta=Delta(content=chunk),
+                            text=chunk,
+                            finish_reason=None,
+                        )
+                    ],
+                    model="OpenAI model name",
+                )
+                await asyncio.sleep(0)  # Yield control to the event loop
+
+        self.completion_stream = gen()
+
+
+class LowerCaseFix(BaseModel):
+    statement: str = Field(
+        description="Validates whether the text is in lower case.",
+        validators=[LowerCase(on_fail=OnFailAction.FIX)],
+    )
+
+
+class LowerCaseNoop(BaseModel):
+    statement: str = Field(
+        description="Validates whether the text is in lower case.",
+        validators=[LowerCase(on_fail=OnFailAction.NOOP)],
+    )
+
+
+class LowerCaseFilter(BaseModel):
+    statement: str = Field(
+        description="Validates whether the text is in lower case.",
+        validators=[LowerCase(on_fail=OnFailAction.FILTER)],
+    )
+
+
+class LowerCaseRefrain(BaseModel):
+    statement: str = Field(
+        description="Validates whether the text is in lower case.",
+        validators=[LowerCase(on_fail=OnFailAction.REFRAIN)],
+    )
+
+
+expected_minsentence_noop_output = ""
+
+
+class MinSentenceLengthNoOp(BaseModel):
+    statement: str = Field(
+        description="Validates whether the text is in lower case.",
+        validators=[MinSentenceLengthValidator(on_fail=OnFailAction.NOOP)],
+    )
+
+
+STR_PROMPT = "Say something nice to me."
+
+PROMPT = """
+Say something nice to me.
+
+${gr.complete_json_suffix}
+"""
+
+POETRY_CHUNKS = [
+    '"John, under ',
+    "GOLDEN bridges",
+    ", roams,\n",
+    "SAN Francisco's ",
+    "hills, his HOME.\n",
+    "Dreams of",
+    " FOG, and salty AIR,\n",
+    "In his HEART",
+    ", he's always THERE.",
+]
+
+
+@pytest.mark.asyncio
+async def test_filter_behavior(mocker):
+    mocker.patch(
+        "litellm.acompletion",
+        return_value=Response(POETRY_CHUNKS),
+    )
+
+    guard = gd.AsyncGuard().use_many(
+        MockDetectPII(
+            on_fail=OnFailAction.FIX,
+            pii_entities="pii",
+            replace_map={"John": "<PERSON>", "SAN Francisco's": "<LOCATION>"},
+        ),
+        LowerCase(on_fail=OnFailAction.FILTER),
+    )
+    prompt = """Write me a 4 line poem about John in San Francisco.
+Make every third word all caps."""
+    gen = await guard(
+        model="gpt-3.5-turbo",
+        max_tokens=10,
+        temperature=0,
+        stream=True,
+        prompt=prompt,
+    )
+
+    text = ""
+    final_res = None
+    async for res in gen:
+        final_res = res
+        text = text + res.validated_output
+
+    assert final_res.raw_llm_output == ", he's always THERE."
+    assert text == ""
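The new test drives an AsyncGuard over a mocked litellm.acompletion stream and asserts that the FILTER action leaves the accumulated validated text empty while raw_llm_output still reflects the final raw chunk. For reference, a small sketch of consuming the mock stream defined above (uses the test module's Response and POETRY_CHUNKS; shown only to illustrate the async-iteration contract the runner relies on):

import asyncio

async def drain(response):
    # Iterate the mock completion stream exactly as the async runner would.
    async for chunk in response.completion_stream:
        print(chunk.choices[0].delta.content, end="")

# asyncio.run(drain(Response(POETRY_CHUNKS)))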
