diff --git a/guardrails/api_client.py b/guardrails/api_client.py
index 6c16f8725..a3a71fd4c 100644
--- a/guardrails/api_client.py
+++ b/guardrails/api_client.py
@@ -104,6 +104,8 @@ def stream_validate(
             )
             if line:
                 json_output = json.loads(line)
+                if json_output.get("error"):
+                    raise Exception(json_output.get("error").get("message"))
                 yield IValidationOutcome.from_dict(json_output)

     def get_history(self, guard_name: str, call_id: str):
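
For context, a minimal sketch of what the new error branch means for consumers of stream_validate; the chunk payloads below are illustrative, not a documented server contract:

    import json

    # Hypothetical chunks as streamed back by the server; the second one
    # carries an error payload of the shape the new branch checks for.
    lines = [
        '{"validatedOutput": "Citrus fruit,", "validationPassed": true}',
        '{"error": {"message": "Something went wrong on the server"}}',
    ]

    try:
        for line in lines:
            json_output = json.loads(line)
            if json_output.get("error"):
                # mirrors the added branch: abort the stream with the message
                raise Exception(json_output["error"]["message"])
            print("chunk:", json_output["validatedOutput"])
    except Exception as e:
        print("stream aborted:", e)
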
diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index 68b1f6da7..a8e9c5fd8 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -574,13 +574,15 @@ async def _stream_server_call(
                 validated_output=validated_output,
                 validation_passed=(validation_output.validation_passed is True),
             )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+            # TODO: Re-enable this once we have a way to get history
+            # from a multi-node server
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
             raise ValueError("AsyncGuard does not have an api client!")
diff --git a/guardrails/classes/llm/llm_response.py b/guardrails/classes/llm/llm_response.py
index 2a0467b64..cf32a9cc4 100644
--- a/guardrails/classes/llm/llm_response.py
+++ b/guardrails/classes/llm/llm_response.py
@@ -47,7 +47,12 @@ def to_interface(self) -> ILLMResponse:
             stream_output = [str(so) for so in copy_2]

         async_stream_output = None
-        if self.async_stream_output:
+        # Don't do this again if the output is already async-iterable;
+        # we're updating ourselves in memory here, so repeating it
+        # can cause issues
+        if self.async_stream_output and not hasattr(
+            self.async_stream_output, "__aiter__"
+        ):
             # tee doesn't work with async iterators
             # This may be destructive
             async_stream_output = []
diff --git a/guardrails/cli/create.py b/guardrails/cli/create.py
index 4668236ba..0048f44f2 100644
--- a/guardrails/cli/create.py
+++ b/guardrails/cli/create.py
@@ -118,7 +118,7 @@ def generate_template_config(
     guard_instantiations = []

     for i, guard in enumerate(template["guards"]):
-        guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])")
+        guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])")
     guard_instantiations = "\n".join(guard_instantiations)

     # Interpolate variables
     output_content = template_content.format(
diff --git a/guardrails/cli/hub/template_config.py.template b/guardrails/cli/hub/template_config.py.template
index 5d5683bdf..b61ffb2f2 100644
--- a/guardrails/cli/hub/template_config.py.template
+++ b/guardrails/cli/hub/template_config.py.template
@@ -1,6 +1,6 @@
 import json
 import os
-from guardrails import Guard
+from guardrails import AsyncGuard, Guard
 from guardrails.hub import {VALIDATOR_IMPORTS}

 try:
diff --git a/guardrails/guard.py b/guardrails/guard.py
index 9f95c451d..200b83e3b 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -1215,10 +1215,12 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[OT]:
                 error="The response from the server was empty!",
             )

-        guard_history = self._api_client.get_history(
-            self.name, validation_output.call_id
-        )
-        self.history.extend([Call.from_interface(call) for call in guard_history])
+        # TODO: Re-enable this when we have history support in
+        # multi-node server environments
+        # guard_history = self._api_client.get_history(
+        #     self.name, validation_output.call_id
+        # )
+        # self.history.extend([Call.from_interface(call) for call in guard_history])

         validation_summaries = []
         if self.history.last and self.history.last.iterations.last:
@@ -1281,13 +1283,15 @@ def _stream_server_call(
                     validated_output=validated_output,
                     validation_passed=(validation_output.validation_passed is True),
                 )
-            if validation_output:
-                guard_history = self._api_client.get_history(
-                    self.name, validation_output.call_id
-                )
-                self.history.extend(
-                    [Call.from_interface(call) for call in guard_history]
-                )
+
+            # TODO: Re-enable this when the server supports multi-node history
+            # if validation_output:
+            #     guard_history = self._api_client.get_history(
+            #         self.name, validation_output.call_id
+            #     )
+            #     self.history.extend(
+            #         [Call.from_interface(call) for call in guard_history]
+            #     )
         else:
             raise ValueError("Guard does not have an api client!")
diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py
index bbe3e21a1..8b22cec8e 100644
--- a/guardrails/llm_providers.py
+++ b/guardrails/llm_providers.py
@@ -498,6 +498,10 @@ def _invoke_llm(
                 ),
             )

+        # These kwargs are Guardrails-only and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = completion(
             model=model,
             *args,
@@ -1088,6 +1092,10 @@ async def invoke_llm(
                 ),
             )

+        # These kwargs are Guardrails-only and should not be passed to LLMs
+        kwargs.pop("reask_prompt", None)
+        kwargs.pop("reask_instructions", None)
+
         response = await acompletion(
             *args,
             **kwargs,
diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py
index aa1b50287..dcd6a7a06 100644
--- a/guardrails/run/async_stream_runner.py
+++ b/guardrails/run/async_stream_runner.py
@@ -148,6 +148,7 @@ async def async_step(
                 _ = self.is_last_chunk(chunk, api)

                 fragment += chunk_text
+
                 results = await validator_service.async_partial_validate(
                     chunk_text,
                     self.metadata,
@@ -157,7 +158,8 @@
                     "$",
                     True,
                 )
-                validators = self.validation_map["$"] or []
+                validators = self.validation_map.get("$", [])
+
                 # collect the result validated_chunk into validation progress
                 # per validator
                 for result in results:
@@ -210,7 +212,7 @@
                         validation_progress[validator_log.validator_name] += chunk
                 # if there is an entry for every validator
                 # run a merge and emit a validation outcome
-                if len(validation_progress) == len(validators):
+                if len(validation_progress) == len(validators) or len(validators) == 0:
                     if refrain_triggered:
                         current = ""
                     else:
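
A standalone illustration of why the lookup change above matters; validation_map here is a plain dict standing in for the runner's attribute, not the real object:

    # No validators are registered for the whole-output key "$".
    validation_map = {}

    # Old lookup: the `or []` never gets a chance to apply, because the
    # missing key raises KeyError first, crashing zero-validator streams.
    try:
        validators = validation_map["$"] or []
    except KeyError as err:
        print("old lookup raised KeyError:", err)

    # New lookup: falls back to an empty list, which the widened emit
    # condition (`or len(validators) == 0`) then handles explicitly.
    validators = validation_map.get("$", [])
    print("new lookup returns:", validators)
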
diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py
index cd850acd5..1eaf70338 100644
--- a/guardrails/telemetry/guard_tracing.py
+++ b/guardrails/telemetry/guard_tracing.py
@@ -145,9 +145,18 @@ def trace_stream_guard(
             res = next(result)  # type: ignore
             # FIXME: This should only be called once;
             # Accumulate the validated output and call at the end
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
-            yield res
+            if not guard_span.is_recording():
+                # Assuming you have a tracer instance
+                tracer = get_tracer(__name__)
+                # Create a new span and link it to the previous span
+                with tracer.start_as_current_span(
+                    "stream_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
+            yield res
         except StopIteration:
             next_exists = False
@@ -180,6 +189,7 @@ def trace_guard_execution(
             result, ValidationOutcome
         ):
             return trace_stream_guard(guard_span, result, history)
+
         add_guard_attributes(guard_span, history, result)
         add_user_attributes(guard_span)
         return result
@@ -204,14 +214,14 @@ async def trace_async_stream_guard(
                 tracer = get_tracer(__name__)
                 # Create a new span and link it to the previous span
                 with tracer.start_as_current_span(
-                    "new_guard_span",  # type: ignore
+                    "async_stream_span",  # type: ignore
                     links=[Link(guard_span.get_span_context())],
                 ) as new_span:
                     guard_span = new_span
                     add_guard_attributes(guard_span, history, res)
                     add_user_attributes(guard_span)
-            yield res
+                    yield res
         except StopIteration:
             next_exists = False
         except StopAsyncIteration:
diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py
index d45c6ee4f..27cf27a97 100644
--- a/guardrails/telemetry/runner_tracing.py
+++ b/guardrails/telemetry/runner_tracing.py
@@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs):
         ) as call_span:
             try:
                 response = fn(*args, **kwargs)
+                if isinstance(response, LLMResponse) and (
+                    response.async_stream_output or response.stream_output
+                ):
+                    # TODO: Iterate, add a call attr each time
+                    return response
                 add_call_attributes(call_span, response, *args, **kwargs)
                 return response
             except Exception as e:
diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py
index 3b626c7ce..6f1e21754 100644
--- a/guardrails/validator_service/validator_service_base.py
+++ b/guardrails/validator_service/validator_service_base.py
@@ -169,6 +169,8 @@ def run_validator(

     # requires at least 2 validators
     def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]:
+        if len(new_values) == 0:
+            return original
         current = new_values.pop()
         while len(new_values) > 0:
             nextval = new_values.pop()
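
The new multi_merge guard in isolation; this standalone function mirrors the patched method's control flow, with the actual merge step elided since it is not shown in the diff:

    from typing import Optional

    def multi_merge(original: str, new_values: list[str]) -> Optional[str]:
        # New early return: an empty list no longer pops from an empty
        # list (IndexError); the original text is returned unchanged.
        if len(new_values) == 0:
            return original
        current = new_values.pop()
        while len(new_values) > 0:
            nextval = new_values.pop()
            # ... the real method merges `current` with `nextval` here ...
            current = nextval
        return current

    print(multi_merge("original text", []))  # -> original text
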
diff --git a/server_ci/config.py b/server_ci/config.py
index b84d637ea..d5a093ca7 100644
--- a/server_ci/config.py
+++ b/server_ci/config.py
@@ -1,6 +1,6 @@
 import json
 import os
-from guardrails import Guard
+from guardrails import AsyncGuard

 try:
     file_path = os.path.join(os.getcwd(), "guard-template.json")
@@ -11,4 +11,4 @@
     SystemExit(1)

 # instantiate guards
-guard0 = Guard.from_dict(guards[0])
+guard0 = AsyncGuard.from_dict(guards[0])
diff --git a/server_ci/tests/test_server.py b/server_ci/tests/test_server.py
index 9d51fde0f..041531d78 100644
--- a/server_ci/tests/test_server.py
+++ b/server_ci/tests/test_server.py
@@ -1,7 +1,7 @@
 import openai
 import os
 import pytest
-from guardrails import Guard, settings
+from guardrails import AsyncGuard, Guard, settings

 # OpenAI compatible Guardrails API Guard
 openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/"
@@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed,
     assert validation_outcome.validated_output == validation_output


+@pytest.mark.asyncio
+async def test_async_guard_validation():
+    settings.use_server = True
+    guard = AsyncGuard(name="test-guard")
+
+    validation_outcome = await guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        temperature=0.0,
+    )
+
+    assert validation_outcome.validation_passed is True  # noqa: E712
+    assert validation_outcome.validated_output == "Citrus fruit,"
+
+
+@pytest.mark.asyncio
+async def test_async_streaming_guard_validation():
+    settings.use_server = True
+    guard = AsyncGuard(name="test-guard")
+
+    async_iterator = await guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        stream=True,
+        temperature=0.0,
+    )
+
+    full_output = ""
+    async for validation_chunk in async_iterator:
+        full_output += validation_chunk.validated_output
+
+    assert full_output == "Citrus fruit,Citrus fruit,"
+
+
+@pytest.mark.asyncio
+async def test_sync_streaming_guard_validation():
+    settings.use_server = True
+    guard = Guard(name="test-guard")
+
+    iterator = guard(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}],
+        stream=True,
+        temperature=0.0,
+    )
+
+    full_output = ""
+    for validation_chunk in iterator:
+        full_output += validation_chunk.validated_output
+
+    assert full_output == "Citrus fruit,Citrus fruit,"
+
+
 @pytest.mark.parametrize(
     "message_content, output, validation_passed, error",
     [