2 changes: 2 additions & 0 deletions guardrails/api_client.py
@@ -104,6 +104,8 @@ def stream_validate(
             )
             if line:
                 json_output = json.loads(line)
+                if json_output.get("error"):
+                    raise Exception(json_output.get("error").get("message"))
                 yield IValidationOutcome.from_dict(json_output)

     def get_history(self, guard_name: str, call_id: str):
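With this change, `stream_validate` surfaces server-side failures instead of yielding them as if they were validation outcomes. A minimal sketch of the new behavior; the `{"error": {"message": ...}}` payload shape is assumed from the diff, not from a documented server contract:

```python
import json

def parse_stream(lines):
    # Sketch: raise on any streamed payload carrying an "error" key,
    # mirroring the check added to stream_validate above.
    for line in lines:
        if line:
            json_output = json.loads(line)
            if json_output.get("error"):
                raise Exception(json_output["error"]["message"])
            yield json_output

ok = '{"validated_output": "hello", "validation_passed": true}'
bad = '{"error": {"message": "validator crashed"}}'
try:
    for outcome in parse_stream([ok, bad]):
        print(outcome["validated_output"])
except Exception as e:
    print(f"stream aborted: {e}")  # stream aborted: validator crashed
```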
16 changes: 9 additions & 7 deletions guardrails/async_guard.py
@@ -574,13 +574,15 @@ async def _stream_server_call(
                     validated_output=validated_output,
                     validation_passed=(validation_output.validation_passed is True),
                 )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+            # TODO re-enable this once we have a way to get history
+            # from a multi-node server
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
             raise ValueError("AsyncGuard does not have an api client!")
7 changes: 6 additions & 1 deletion guardrails/classes/llm/llm_response.py
@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
             stream_output = [str(so) for so in copy_2]

         async_stream_output = None
-        if self.async_stream_output:
+        # Don't do this again if the output is already aiter-able;
+        # we're updating ourselves here in memory,
+        # so doing it twice can cause issues.
+        if self.async_stream_output and not hasattr(
+            self.async_stream_output, "__aiter__"
+        ):
             # tee doesn't work with async iterators
             # This may be destructive
             async_stream_output = []
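The new `__aiter__` check matters because materializing an async generator is destructive: it can be consumed exactly once. A short sketch of the failure mode the guard avoids (names here are illustrative):

```python
import asyncio

async def chunks():
    for c in ("hel", "lo"):
        yield c

async def main():
    stream = chunks()
    first = [c async for c in stream]   # drains the generator
    second = [c async for c in stream]  # nothing left on a second pass
    print(first, second)                # ['hel', 'lo'] []

asyncio.run(main())
```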
26 changes: 15 additions & 11 deletions guardrails/guard.py
@@ -1215,10 +1215,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O
                 error="The response from the server was empty!",
             )

-        guard_history = self._api_client.get_history(
-            self.name, validation_output.call_id
-        )
-        self.history.extend([Call.from_interface(call) for call in guard_history])
+        # TODO re-enable this when we have history support in
+        # multi-node server environments
+        # guard_history = self._api_client.get_history(
+        #     self.name, validation_output.call_id
+        # )
+        # self.history.extend([Call.from_interface(call) for call in guard_history])

         validation_summaries = []
         if self.history.last and self.history.last.iterations.last:
@@ -1281,13 +1283,15 @@ def _stream_server_call(
                     validated_output=validated_output,
                     validation_passed=(validation_output.validation_passed is True),
                 )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+
+            # TODO re-enable this when the server supports multi-node history
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
             raise ValueError("Guard does not have an api client!")
8 changes: 8 additions & 0 deletions guardrails/llm_providers.py
@@ -498,6 +498,10 @@ def _invoke_llm(
             ),
         )

+        # these are guardrails-only kwargs and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = completion(
             model=model,
             *args,
@@ -1088,6 +1092,10 @@ async def invoke_llm(
             ),
         )

+        # these are guardrails-only kwargs and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = await acompletion(
             *args,
             **kwargs,
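`reask_prompt` and `reask_instructions` are consumed by Guardrails itself; forwarding them to the underlying LLM client would pass unknown parameters to the provider. A sketch of the stripping pattern (the helper name is hypothetical):

```python
def strip_guardrails_kwargs(kwargs: dict) -> dict:
    # Remove Guardrails-only kwargs before forwarding the rest to the
    # LLM client, as the diff does before completion()/acompletion().
    kwargs.pop("reask_prompt", None)
    kwargs.pop("reask_instructions", None)
    return kwargs

llm_kwargs = strip_guardrails_kwargs({
    "temperature": 0.0,
    "reask_prompt": "Fix the output: ${previous_response}",
})
assert "reask_prompt" not in llm_kwargs  # safe to forward now
```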
6 changes: 4 additions & 2 deletions guardrails/run/async_stream_runner.py
@@ -148,6 +148,7 @@ async def async_step(
             _ = self.is_last_chunk(chunk, api)

             fragment += chunk_text
+
             results = await validator_service.async_partial_validate(
                 chunk_text,
                 self.metadata,
@@ -157,7 +158,8 @@
                 "$",
                 True,
             )
-            validators = self.validation_map["$"] or []
+            validators = self.validation_map.get("$", [])
+
             # collect the result validated_chunk into validation progress
             # per validator
             for result in results:
@@ -210,7 +212,7 @@ async def async_step(
             validation_progress[validator_log.validator_name] += chunk
             # if there is an entry for every validator
             # run a merge and emit a validation outcome
-            if len(validation_progress) == len(validators):
+            if len(validation_progress) == len(validators) or len(validators) == 0:
                 if refrain_triggered:
                     current = ""
                 else:
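Two small behavior changes here: `validation_map.get("$", [])` tolerates a guard with no validators registered at `$`, and the extra `len(validators) == 0` clause makes the chunk pass-through explicit for that case. A sketch of the emit condition, with illustrative names:

```python
def should_emit(validation_progress: dict, validators: list) -> bool:
    # Emit once every validator has contributed a chunk, or
    # immediately when no validators are configured at all.
    return len(validation_progress) == len(validators) or len(validators) == 0

assert should_emit({}, [])                       # pass-through guard
assert not should_emit({"a": "hi"}, ["a", "b"])  # still waiting on "b"
assert should_emit({"a": "hi", "b": "hi"}, ["a", "b"])
```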
20 changes: 15 additions & 5 deletions guardrails/telemetry/guard_tracing.py
@@ -145,9 +145,18 @@ def trace_stream_guard(
             res = next(result)  # type: ignore
             # FIXME: This should only be called once;
             # Accumulate the validated output and call at the end
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
-            yield res
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "stream_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
+            yield res
         except StopIteration:
             next_exists = False

@@ -180,6 +189,7 @@ def trace_guard_execution(
                 result, ValidationOutcome
             ):
                 return trace_stream_guard(guard_span, result, history)
+
             add_guard_attributes(guard_span, history, result)
             add_user_attributes(guard_span)
             return result
@@ -204,14 +214,14 @@ async def trace_async_stream_guard(
             tracer = get_tracer(__name__)
             # Create a new span and link it to the previous span
             with tracer.start_as_current_span(
-                "new_guard_span",  # type: ignore
+                "async_stream_span",  # type: ignore
                 links=[Link(guard_span.get_span_context())],
             ) as new_span:
                 guard_span = new_span

             add_guard_attributes(guard_span, history, res)
             add_user_attributes(guard_span)
-                yield res
+            yield res
         except StopIteration:
             next_exists = False
         except StopAsyncIteration:
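Both streaming tracers now handle a guard span that has already ended: once `is_recording()` turns false, each subsequent chunk gets a fresh span carrying an OpenTelemetry `Link` back to the original span's context, keeping the trace connected across a long-lived stream. A standalone sketch of the pattern (span name and attribute are illustrative):

```python
from opentelemetry import trace
from opentelemetry.trace import Link

tracer = trace.get_tracer(__name__)

def record_chunk(guard_span: trace.Span) -> None:
    # If the original span has ended, open a fresh span linked back
    # to it instead of writing attributes to a non-recording span.
    if guard_span.is_recording():
        guard_span.set_attribute("stream.chunk", True)
        return
    with tracer.start_as_current_span(
        "stream_guard_span",
        links=[Link(guard_span.get_span_context())],
    ) as linked_span:
        linked_span.set_attribute("stream.chunk", True)
```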
5 changes: 5 additions & 0 deletions guardrails/telemetry/runner_tracing.py
@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
         ) as call_span:
             try:
                 response = fn(*args, **kwargs)
+                if isinstance(response, LLMResponse) and (
+                    response.async_stream_output or response.stream_output
+                ):
+                    # TODO: Iterate, add a call attr each time
+                    return response
                 add_call_attributes(call_span, response, *args, **kwargs)
                 return response
             except Exception as e:
2 changes: 2 additions & 0 deletions guardrails/validator_service/validator_service_base.py
@@ -169,6 +169,8 @@ def run_validator(

     # requires at least 2 validators
     def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
+        if len(new_values) == 0:
+            return original
         current = new_values.pop()
         while len(new_values) > 0:
             nextval = new_values.pop()
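`multi_merge` pops from `new_values`, so an empty list previously raised `IndexError`; the early return hands back the original text when there is nothing to merge. A sketch with a placeholder merge step (the real method merges validator outputs):

```python
from typing import Optional

def multi_merge(original: str, new_values: list[str]) -> Optional[str]:
    if len(new_values) == 0:
        return original  # nothing to merge; previously IndexError on pop()
    current = new_values.pop()
    while len(new_values) > 0:
        nextval = new_values.pop()
        # Placeholder merge step: keep the shorter string.
        current = current if len(current) <= len(nextval) else nextval
    return current

assert multi_merge("raw text", []) == "raw text"
assert multi_merge("raw text", ["a", "ab"]) == "a"
```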