Skip to content

Commit 10a1dde

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/litellm_cleanup
2 parents ef2863b + 80fc455 commit 10a1dde

File tree

13 files changed

+131
-38
lines changed

13 files changed

+131
-38
lines changed

guardrails/api_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def stream_validate(
104104
)
105105
if line:
106106
json_output = json.loads(line)
107+
if json_output.get("error"):
108+
raise Exception(json_output.get("error").get("message"))
107109
yield IValidationOutcome.from_dict(json_output)
108110

109111
def get_history(self, guard_name: str, call_id: str):

guardrails/async_guard.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,13 +529,15 @@ async def _stream_server_call(
529529
validated_output=validated_output,
530530
validation_passed=(validation_output.validation_passed is True),
531531
)
532-
if validation_output:
533-
guard_history = self._api_client.get_history(
534-
self.name, validation_output.call_id
535-
)
536-
self.history.extend(
537-
[Call.from_interface(call) for call in guard_history]
538-
)
532+
# TODO re-enable this once we have a way to get history
533+
# from a multi-node server
534+
# if validation_output:
535+
# guard_history = self._api_client.get_history(
536+
# self.name, validation_output.call_id
537+
# )
538+
# self.history.extend(
539+
# [Call.from_interface(call) for call in guard_history]
540+
# )
539541
else:
540542
raise ValueError("AsyncGuard does not have an api client!")
541543

guardrails/classes/llm/llm_response.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
4747
stream_output = [str(so) for so in copy_2]
4848

4949
async_stream_output = None
50-
if self.async_stream_output:
50+
# don't do this again if it's already aiter-able; we're updating
51+
# ourselves in memory here, so repeating it
52+
# can cause issues
53+
if self.async_stream_output and not hasattr(
54+
self.async_stream_output, "__aiter__"
55+
):
5156
# tee doesn't work with async iterators
5257
# This may be destructive
5358
async_stream_output = []

guardrails/cli/create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def generate_template_config(
118118
guard_instantiations = []
119119

120120
for i, guard in enumerate(template["guards"]):
121-
guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
121+
guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
122122
guard_instantiations = "\n".join(guard_instantiations)
123123
# Interpolate variables
124124
output_content = template_content.format(

guardrails/cli/hub/template_config.py.template

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import os
3-
from guardrails import Guard
3+
from guardrails import AsyncGuard, Guard
44
from guardrails.hub import {VALIDATOR_IMPORTS}
55

66
try:

guardrails/guard.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -896,11 +896,12 @@ def __call__(
896896

897897
messages = messages or self._exec_opts.messages or []
898898

899-
# if messages is not None and not len(messages):
900-
# raise RuntimeError(
901-
# "You must provide messages. "
902-
# "Alternatively, you can provide a prompt in the Schema constructor."
903-
# )
899+
if messages is not None and not len(messages):
900+
raise RuntimeError(
901+
"You must provide messages. "
902+
"Alternatively, you can provide messages in the Schema constructor."
903+
)
904+
904905
return trace_guard_execution(
905906
self.name,
906907
self.history,
@@ -1113,10 +1114,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
11131114
error="The response from the server was empty!",
11141115
)
11151116

1116-
guard_history = self._api_client.get_history(
1117-
self.name, validation_output.call_id
1118-
)
1119-
self.history.extend([Call.from_interface(call) for call in guard_history])
1117+
# TODO re-enable this when we have history support in
1118+
# multi-node server environments
1119+
# guard_history = self._api_client.get_history(
1120+
# self.name, validation_output.call_id
1121+
# )
1122+
# self.history.extend([Call.from_interface(call) for call in guard_history])
11201123

11211124
validation_summaries = []
11221125
if self.history.last and self.history.last.iterations.last:
@@ -1179,13 +1182,15 @@ def _stream_server_call(
11791182
validated_output=validated_output,
11801183
validation_passed=(validation_output.validation_passed is True),
11811184
)
1182-
if validation_output:
1183-
guard_history = self._api_client.get_history(
1184-
self.name, validation_output.call_id
1185-
)
1186-
self.history.extend(
1187-
[Call.from_interface(call) for call in guard_history]
1188-
)
1185+
1186+
# TODO re-enable this when the server supports multi-node history
1187+
# if validation_output:
1188+
# guard_history = self._api_client.get_history(
1189+
# self.name, validation_output.call_id
1190+
# )
1191+
# self.history.extend(
1192+
# [Call.from_interface(call) for call in guard_history]
1193+
# )
11891194
else:
11901195
raise ValueError("Guard does not have an api client!")
11911196

guardrails/llm_providers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,11 @@ def _invoke_llm(
217217
"function_call", safe_get(function_calling_tools, 0)
218218
),
219219
)
220-
kwargs.pop("instructions", None)
221-
kwargs.pop("prompt", None)
220+
221+
# these kwargs are guardrails-only and should not be passed to LLMs
222+
kwargs.pop("reask_prompt", None)
223+
kwargs.pop("reask_instructions", None)
224+
222225
response = completion(
223226
model=model,
224227
*args,
@@ -666,6 +669,10 @@ async def invoke_llm(
666669
),
667670
)
668671

672+
# these kwargs are guardrails-only and should not be passed to LLMs
673+
kwargs.pop("reask_prompt", None)
674+
kwargs.pop("reask_instructions", None)
675+
669676
response = await acompletion(
670677
*args,
671678
**kwargs,

guardrails/run/async_stream_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def async_step(
125125
_ = self.is_last_chunk(chunk, api)
126126

127127
fragment += chunk_text
128+
128129
results = await validator_service.async_partial_validate(
129130
chunk_text,
130131
self.metadata,
@@ -134,7 +135,8 @@ async def async_step(
134135
"$",
135136
True,
136137
)
137-
validators = self.validation_map["$"] or []
138+
validators = self.validation_map.get("$", [])
139+
138140
# collect the result validated_chunk into validation progress
139141
# per validator
140142
for result in results:
@@ -187,7 +189,7 @@ async def async_step(
187189
validation_progress[validator_log.validator_name] += chunk
188190
# if there is an entry for every validator
189191
# run a merge and emit a validation outcome
190-
if len(validation_progress) == len(validators):
192+
if len(validation_progress) == len(validators) or len(validators) == 0:
191193
if refrain_triggered:
192194
current = ""
193195
else:

guardrails/telemetry/guard_tracing.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,18 @@ def trace_stream_guard(
142142
res = next(result) # type: ignore
143143
# FIXME: This should only be called once;
144144
# Accumulate the validated output and call at the end
145-
add_guard_attributes(guard_span, history, res)
146-
add_user_attributes(guard_span)
147-
yield res
145+
if not guard_span.is_recording():
146+
# Assuming you have a tracer instance
147+
tracer = get_tracer(__name__)
148+
# Create a new span and link it to the previous span
149+
with tracer.start_as_current_span(
150+
"stream_guard_span", # type: ignore
151+
links=[Link(guard_span.get_span_context())],
152+
) as new_span:
153+
guard_span = new_span
154+
add_guard_attributes(guard_span, history, res)
155+
add_user_attributes(guard_span)
156+
yield res
148157
except StopIteration:
149158
next_exists = False
150159

@@ -177,6 +186,7 @@ def trace_guard_execution(
177186
result, ValidationOutcome
178187
):
179188
return trace_stream_guard(guard_span, result, history)
189+
180190
add_guard_attributes(guard_span, history, result)
181191
add_user_attributes(guard_span)
182192
return result
@@ -201,14 +211,14 @@ async def trace_async_stream_guard(
201211
tracer = get_tracer(__name__)
202212
# Create a new span and link it to the previous span
203213
with tracer.start_as_current_span(
204-
"new_guard_span", # type: ignore
214+
"async_stream_span", # type: ignore
205215
links=[Link(guard_span.get_span_context())],
206216
) as new_span:
207217
guard_span = new_span
208218

209219
add_guard_attributes(guard_span, history, res)
210220
add_user_attributes(guard_span)
211-
yield res
221+
yield res
212222
except StopIteration:
213223
next_exists = False
214224
except StopAsyncIteration:

guardrails/telemetry/runner_tracing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
265265
) as call_span:
266266
try:
267267
response = fn(*args, **kwargs)
268+
if isinstance(response, LLMResponse) and (
269+
response.async_stream_output or response.stream_output
270+
):
271+
# TODO: Iterate, add a call attr each time
272+
return response
268273
add_call_attributes(call_span, response, *args, **kwargs)
269274
return response
270275
except Exception as e:

0 commit comments

Comments
 (0)