diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py
index 9af740711..d21af7a3f 100644
--- a/guardrails/async_guard.py
+++ b/guardrails/async_guard.py
@@ -369,7 +369,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Here we have an async generator
@@ -391,7 +395,11 @@ async def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         # Why are we using a different method here instead of just overriding?
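Note on the hunks above: `_allow_metrics_collection` is now treated as tri-state (`Optional[bool]`). Previously, an unset flag (`None`) ran through `not None` and produced `disable_tracer=True`, silently disabling the tracer even though the user never opted out. A minimal sketch of the mapping these hunks encode; `to_disable_tracer` is an illustrative helper name, not part of the codebase:

from typing import Optional

def to_disable_tracer(allow_metrics_collection: Optional[bool]) -> Optional[bool]:
    # bool -> invert the opt-in flag into an opt-out flag
    # None -> leave unset so downstream code can apply its own default
    if isinstance(allow_metrics_collection, bool):
        return not allow_metrics_collection
    return None

assert to_disable_tracer(True) is False   # metrics allowed, tracer enabled
assert to_disable_tracer(False) is True   # metrics refused, tracer disabled
assert to_disable_tracer(None) is None    # unset stays unset (old code yielded True)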
diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py
index 1a0cb5e16..4c492e1f2 100644
--- a/guardrails/cli/hub/install.py
+++ b/guardrails/cli/hub/install.py
@@ -10,9 +10,14 @@
 from guardrails.cli.version import version_warnings_if_applicable
 
 
-@hub_command.command()
+# Quick note: This is the command for `guardrails hub install`. We rename the
+# function def to prevent confusion, lest people import it directly and call it
+# with a string for package_uris instead of a list, which behaves oddly. If you
+# need to call install from a script, please consider importing install from
+# guardrails, not guardrails.cli.hub.install.
+@hub_command.command(name="install")
 @trace(name="guardrails-cli/hub/install")
-def install(
+def install_cli(
     package_uris: List[str] = typer.Argument(
         ...,
         help="URIs to the packages to install. Example: hub://guardrails/regex_match hub://guardrails/toxic_language",
     ),
@@ -33,6 +38,17 @@ def install(
     ),
 ):
     try:
+        if isinstance(package_uris, str):
+            logger.error(
+                f"`install` in {__file__} was called with a string instead of "
+                "a list! This can happen if it is invoked directly instead of "
+                "being run via the CLI. Did you mean to import `from guardrails import "
+                "install` instead? Recovering..."
+            )
+            package_uris = [
+                package_uris,
+            ]
+
         from guardrails.hub.install import install_multiple
 
         def confirm():
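Why the rename above matters: a `str` is itself an iterable of characters, so invoking the old `install()` directly with a single URI string would be treated as many one-character package URIs. A sketch of the failure mode and of the supported programmatic entry point (the character loop is illustrative; the recommended import comes from the comment above and from the docstring in guardrails/hub/install.py):

# Illustrative: why a bare string for package_uris "behaves oddly".
package_uris = "hub://guardrails/regex_match"
for uri in package_uris:
    print(uri)  # prints "h", "u", "b", ":", ... one character per iteration

# Supported programmatic entry point:
from guardrails import install

install("hub://guardrails/regex_match")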
diff --git a/guardrails/guard.py b/guardrails/guard.py
index 8f3fd4332..67b9dfc03 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -865,7 +865,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         return runner(call_log=call_log, prompt_params=prompt_params)
@@ -884,7 +888,11 @@ def _exec(
             output=llm_output,
             base_model=self._base_model,
             full_schema_reask=full_schema_reask,
-            disable_tracer=(not self._allow_metrics_collection),
+            disable_tracer=(
+                not self._allow_metrics_collection
+                if isinstance(self._allow_metrics_collection, bool)
+                else None
+            ),
             exec_options=self._exec_opts,
         )
         call = runner(call_log=call_log, prompt_params=prompt_params)
diff --git a/guardrails/hub/install.py b/guardrails/hub/install.py
index 046ee073f..677e7c392 100644
--- a/guardrails/hub/install.py
+++ b/guardrails/hub/install.py
@@ -53,7 +53,7 @@ def install(
 
     Examples:
         >>> RegexMatch = install("hub://guardrails/regex_match").RegexMatch
 
-        >>> install("hub://guardrails/regex_match);
+        >>> install("hub://guardrails/regex_match")
         >>> import guardrails.hub.regex_match as regex_match
     """
diff --git a/guardrails/hub_telemetry/hub_tracing.py b/guardrails/hub_telemetry/hub_tracing.py
index a02b1b688..a520a104f 100644
--- a/guardrails/hub_telemetry/hub_tracing.py
+++ b/guardrails/hub_telemetry/hub_tracing.py
@@ -3,6 +3,7 @@
     Any,
     Dict,
     Optional,
+    AsyncGenerator,
 )
 
 from opentelemetry.trace import Span
@@ -224,7 +225,7 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
-async def _run_async_gen(fn, *args, **kwargs):
+async def _run_async_gen(fn, *args, **kwargs) -> AsyncGenerator[Any, None]:
    gen = fn(*args, **kwargs)
    async for item in gen:
        yield item
@@ -238,7 +239,7 @@ def async_trace_stream(
 ):
     def decorator(fn):
         @wraps(fn)
-        async def wrapper(*args, **kwargs):
+        def wrapper(*args, **kwargs):
             hub_telemetry = HubTelemetry()
             if hub_telemetry._enabled and hub_telemetry._tracer is not None:
                 with hub_telemetry._tracer.start_span(
@@ -252,7 +253,7 @@ async def wrapper(*args, **kwargs):
                     nonlocal origin
                     origin = origin if origin is not None else name
                     add_attributes(span, attrs, name, origin, *args, **kwargs)
-                    return _run_async_gen(fn, *args, **kwargs)
+                    return fn(*args, **kwargs)
             else:
                 return fn(*args, **kwargs)
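A note on the `async def wrapper` to `def wrapper` change above: `fn` here is an async generator function, and calling it does not run or await anything; it synchronously returns an async generator object. The old `async def wrapper` therefore returned a coroutine that callers had to `await` before they could iterate, changing the decorated function's calling convention. A self-contained sketch of the difference (toy names, not the guardrails API):

import asyncio

async def numbers():  # an async generator function, like the decorated fn
    for i in range(3):
        yield i

def sync_wrapper():  # like the new `def wrapper`
    return numbers()  # returns the async generator directly

async def coro_wrapper():  # like the old `async def wrapper`
    return numbers()  # returns a coroutine wrapping the generator

async def main():
    async for n in sync_wrapper():  # iterable as-is
        print(n)
    async for n in await coro_wrapper():  # needs an extra await first
        print(n)

asyncio.run(main())

A related lifetime wrinkle remains: the `with ... start_span(...)` block exits as soon as `wrapper` returns, before the stream is consumed. `trace_async_stream_guard` in guardrails/telemetry/guard_tracing.py below compensates for spans that end early by checking `guard_span.is_recording()` and linking a fresh span.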
diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py
index 8f39c21f2..29b79fbd0 100644
--- a/guardrails/run/async_stream_runner.py
+++ b/guardrails/run/async_stream_runner.py
@@ -153,6 +153,10 @@ async def async_step(
                 validate_subschema=True,
                 stream=True,
             )
+            # TODO: Figure out why validated_fragment can be None here, and how
+            # the other places we handle streams avoid it.
+            if validated_fragment is None:
+                validated_fragment = ""
+
             if isinstance(validated_fragment, SkeletonReAsk):
                 raise ValueError(
                     "Received fragment schema is an invalid sub-schema "
@@ -165,7 +169,7 @@
                     "Reasks are not yet supported with streaming. Please "
                     "remove reasks from schema or disable streaming."
                 )
-            validation_response += cast(str, validated_fragment)
+            validation_response += validated_fragment
             passed = call_log.status == pass_status
             yield ValidationOutcome(
                 call_id=call_log.id,  # type: ignore
diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py
index 31e074b1b..cd850acd5 100644
--- a/guardrails/telemetry/guard_tracing.py
+++ b/guardrails/telemetry/guard_tracing.py
@@ -11,7 +11,7 @@
 )
 
 from opentelemetry import context, trace
-from opentelemetry.trace import StatusCode, Tracer, Span
+from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer
 
 from guardrails.settings import settings
 from guardrails.classes.generic.stack import Stack
@@ -22,6 +22,10 @@
 from guardrails.telemetry.common import add_user_attributes
 from guardrails.version import GUARDRAILS_VERSION
 
+import sys
+
+if sys.version_info < (3, 10):
+    from guardrails.utils.polyfills import anext
 
 # from sentence_transformers import SentenceTransformer
 # import numpy as np
@@ -195,8 +199,18 @@ async def trace_async_stream_guard(
     while next_exists:
         try:
             res = await anext(result)  # type: ignore
-            add_guard_attributes(guard_span, history, res)
-            add_user_attributes(guard_span)
+            if not guard_span.is_recording():
+                # The original guard span has already ended; get a tracer and
+                # start a new span linked back to the ended span's context.
+                tracer = get_tracer(__name__)
+                with tracer.start_as_current_span(
+                    "new_guard_span",  # type: ignore
+                    links=[Link(guard_span.get_span_context())],
+                ) as new_span:
+                    guard_span = new_span
+
+            add_guard_attributes(guard_span, history, res)
+            add_user_attributes(guard_span)
             yield res
         except StopIteration:
             next_exists = False
diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py
index 5a6ef6f85..d45c6ee4f 100644
--- a/guardrails/telemetry/runner_tracing.py
+++ b/guardrails/telemetry/runner_tracing.py
@@ -21,6 +21,10 @@
 from guardrails.utils.safe_get import safe_get
 from guardrails.version import GUARDRAILS_VERSION
 
+import sys
+
+if sys.version_info < (3, 10):
+    from guardrails.utils.polyfills import anext
 
 #########################################
 ### START Runner.step Instrumentation ###
diff --git a/guardrails/utils/hub_telemetry_utils.py b/guardrails/utils/hub_telemetry_utils.py
index d8e5d2dc9..d14dcaf6d 100644
--- a/guardrails/utils/hub_telemetry_utils.py
+++ b/guardrails/utils/hub_telemetry_utils.py
@@ -57,7 +57,6 @@ def initialize_tracer(
         """Initializes a tracer for Guardrails Hub."""
         if enabled is None:
             enabled = settings.rc.enable_metrics or False
-        self._enabled = enabled
         self._carrier = {}
         self._service_name = service_name
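Both telemetry modules above import an `anext` polyfill because the builtin only landed in Python 3.10. The polyfill itself is not shown in this diff; a plausible minimal version is just a thin wrapper over the async-iterator protocol (an assumption, the real guardrails.utils.polyfills implementation may differ):

# Plausible sketch of an anext() polyfill for Python 3.8/3.9; assumed, since
# guardrails/utils/polyfills.py is not part of this diff.
from typing import AsyncIterator, TypeVar

T = TypeVar("T")

async def anext(iterator: AsyncIterator[T]) -> T:
    # Delegate to __anext__; raises StopAsyncIteration when the iterator
    # is exhausted, matching the 3.10+ builtin.
    return await iterator.__anext__()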
" + f"Please return a longer output, " + f"that is shorter than {self._max} characters.", + ) + return PassResult(validated_chunk=value) + + def validate_stream(self, chunk: Any, metadata: Dict, **kwargs) -> ValidationResult: + return super().validate_stream(chunk, metadata, **kwargs) + + +class Delta: + content: str + + def __init__(self, content): + self.content = content + + +class Choice: + text: str + finish_reason: str + index: int + delta: Delta + + def __init__(self, text, delta, finish_reason, index=0): + self.index = index + self.delta = delta + self.text = text + self.finish_reason = finish_reason + + +class MockOpenAIV1ChunkResponse: + choices: list + model: str + + def __init__(self, choices, model): + self.choices = choices + self.model = model + + +class Response: + def __init__(self, chunks): + self.chunks = chunks + + async def gen(): + for chunk in self.chunks: + yield MockOpenAIV1ChunkResponse( + choices=[ + Choice( + delta=Delta(content=chunk), + text=chunk, + finish_reason=None, + ) + ], + model="OpenAI model name", + ) + await asyncio.sleep(0) # Yield control to the event loop + + self.completion_stream = gen() + + +POETRY_CHUNKS = [ + "John, under ", + "GOLDEN bridges", + ", roams,\n", + "SAN Francisco's ", + "hills, his HOME.\n", + "Dreams of", + " FOG, and salty AIR,\n", + "In his HEART", + ", he's always THERE.", +] + + +@pytest.mark.asyncio +async def test_filter_behavior(mocker): + mocker.patch( + "litellm.acompletion", + return_value=Response(POETRY_CHUNKS), + ) + + guard = gd.AsyncGuard().use_many( + MockDetectPII( + on_fail=OnFailAction.FIX, + pii_entities="pii", + replace_map={"John": "", "SAN Francisco's": ""}, + ), + LowerCase(on_fail=OnFailAction.FILTER), + ) + prompt = """Write me a 4 line poem about John in San Francisco. + Make every third word all caps.""" + gen = await guard( + model="gpt-3.5-turbo", + max_tokens=10, + temperature=0, + stream=True, + prompt=prompt, + ) + + text = "" + final_res = None + async for res in gen: + final_res = res + text += res.validated_output + + assert final_res.raw_llm_output == ", he's always THERE." + # TODO deep dive this + assert text == ( + "John, under GOLDEN bridges, roams,\n" + "SAN Francisco's Dreams of FOG, and salty AIR,\n" + "In his HEART" + )