From 04e2b84217786646baf6b1f2be2e63badd5a8832 Mon Sep 17 00:00:00 2001 From: zsimjee Date: Tue, 8 Oct 2024 10:00:57 -0700 Subject: [PATCH 01/13] fix stream generation --- guardrails/telemetry/guard_tracing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index cd850acd5..b6dbb604b 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -179,10 +179,11 @@ def trace_guard_execution( if isinstance(result, Iterator) and not isinstance( result, ValidationOutcome ): - return trace_stream_guard(guard_span, result, history) - add_guard_attributes(guard_span, history, result) - add_user_attributes(guard_span) - return result + for res in trace_stream_guard(guard_span, result, history): + yield res + else: + add_guard_attributes(guard_span, history, result) + add_user_attributes(guard_span) except Exception as e: guard_span.set_status(status=StatusCode.ERROR, description=str(e)) raise e From 6992b22c34eba50eec80115ecb320e94db2dbf7d Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 10:03:46 -0700 Subject: [PATCH 02/13] updates for history, streaming handling and fixes --- guardrails/api_client.py | 2 ++ guardrails/guard.py | 7 ++++--- guardrails/llm_providers.py | 6 ++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/guardrails/api_client.py b/guardrails/api_client.py index 6c16f8725..967e441b0 100644 --- a/guardrails/api_client.py +++ b/guardrails/api_client.py @@ -104,6 +104,8 @@ def stream_validate( ) if line: json_output = json.loads(line) + if(json_output.get("error")): + raise Exception(json_output.get("error").get("message")) yield IValidationOutcome.from_dict(json_output) def get_history(self, guard_name: str, call_id: str): diff --git a/guardrails/guard.py b/guardrails/guard.py index 9f95c451d..205a0e141 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1285,9 +1285,10 @@ def _stream_server_call( guard_history = self._api_client.get_history( self.name, validation_output.call_id ) - self.history.extend( - [Call.from_interface(call) for call in guard_history] - ) + # TODO renable this. doesnt work in a multiple server environment + # self.history.extend( + # [Call.from_interface(call) for call in guard_history] + # ) else: raise ValueError("Guard does not have an api client!") diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index bbe3e21a1..91ae567dc 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -498,6 +498,9 @@ def _invoke_llm( ), ) + kwargs.pop("reask_prompt", None) + kwargs.pop("reask_instructions", None) + response = completion( model=model, *args, @@ -1088,6 +1091,9 @@ async def invoke_llm( ), ) + kwargs.pop("reask_prompt", None) + kwargs.pop("reask_instructions", None) + response = await acompletion( *args, **kwargs, From db39c0ef6ddf051ea5cc2084a64de4cb2e7de923 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 10:58:32 -0700 Subject: [PATCH 03/13] more history updates --- guardrails/guard.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 205a0e141..061010a9a 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1215,10 +1215,10 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O error="The response from the server was empty!", ) - guard_history = self._api_client.get_history( - self.name, validation_output.call_id - ) - self.history.extend([Call.from_interface(call) for call in guard_history]) + # guard_history = self._api_client.get_history( + # self.name, validation_output.call_id + # ) + # self.history.extend([Call.from_interface(call) for call in guard_history]) validation_summaries = [] if self.history.last and self.history.last.iterations.last: From 59117ccdf2b0ec13f8fcf6a1c2020b3ecefe2538 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 11:34:07 -0700 Subject: [PATCH 04/13] fix span error --- guardrails/telemetry/guard_tracing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index cd850acd5..abb28b903 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -145,8 +145,17 @@ def trace_stream_guard( res = next(result) # type: ignore # FIXME: This should only be called once; # Accumulate the validated output and call at the end - add_guard_attributes(guard_span, history, res) - add_user_attributes(guard_span) + if not guard_span.is_recording(): + # Assuming you have a tracer instance + tracer = get_tracer(__name__) + # Create a new span and link it to the previous span + with tracer.start_as_current_span( + "new_guard_span", # type: ignore + links=[Link(guard_span.get_span_context())], + ) as new_span: + guard_span = new_span + add_guard_attributes(guard_span, history, res) + add_user_attributes(guard_span) yield res except StopIteration: next_exists = False From 32a2b13651a8b920b4f14bd5b57c892abb0d320c Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 11:39:49 -0700 Subject: [PATCH 05/13] fix async history error --- guardrails/guard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 061010a9a..0131fbf77 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1370,7 +1370,7 @@ def to_dict(self) -> Dict[str, Any]: description=self.description, validators=self.validators, output_schema=self.output_schema, - history=[c.to_interface() for c in self.history], # type: ignore + history=[], # type: ignore ) return i_guard.to_dict() From 78b4e7d427e4a083d13171703ea523c31d38d7d0 Mon Sep 17 00:00:00 2001 From: zsimjee Date: Tue, 8 Oct 2024 12:39:00 -0700 Subject: [PATCH 06/13] one chunk at a time --- guardrails/telemetry/runner_tracing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index d45c6ee4f..27cf27a97 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -265,6 +265,11 @@ def trace_call_wrapper(*args, **kwargs): ) as call_span: try: response = fn(*args, **kwargs) + if isinstance(response, LLMResponse) and ( + response.async_stream_output or response.stream_output + ): + # TODO: Iterate, add a call attr each time + return response add_call_attributes(call_span, response, *args, **kwargs) return response except Exception as e: From fa6bbdf30cac2fc9f8d6b11d94e091db6cf28f7b Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 13:35:03 -0700 Subject: [PATCH 07/13] more updates for perf --- guardrails/async_guard.py | 14 +++--- guardrails/run/async_stream_runner.py | 6 ++- guardrails/telemetry/guard_tracing.py | 44 +++++++++---------- .../validator_service_base.py | 2 + 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 68b1f6da7..7dabf716a 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -574,13 +574,13 @@ async def _stream_server_call( validated_output=validated_output, validation_passed=(validation_output.validation_passed is True), ) - if validation_output: - guard_history = self._api_client.get_history( - self.name, validation_output.call_id - ) - self.history.extend( - [Call.from_interface(call) for call in guard_history] - ) + # if validation_output: + # guard_history = self._api_client.get_history( + # self.name, validation_output.call_id + # ) + # self.history.extend( + # [Call.from_interface(call) for call in guard_history] + # ) else: raise ValueError("AsyncGuard does not have an api client!") diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index aa1b50287..dcd6a7a06 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -148,6 +148,7 @@ async def async_step( _ = self.is_last_chunk(chunk, api) fragment += chunk_text + results = await validator_service.async_partial_validate( chunk_text, self.metadata, @@ -157,7 +158,8 @@ async def async_step( "$", True, ) - validators = self.validation_map["$"] or [] + validators = self.validation_map.get("$", []) + # collect the result validated_chunk into validation progress # per validator for result in results: @@ -210,7 +212,7 @@ async def async_step( validation_progress[validator_log.validator_name] += chunk # if there is an entry for every validator # run a merge and emit a validation outcome - if len(validation_progress) == len(validators): + if len(validation_progress) == len(validators) or len(validators) == 0: if refrain_triggered: current = "" else: diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 8989308fc..23028ba41 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -145,17 +145,17 @@ def trace_stream_guard( res = next(result) # type: ignore # FIXME: This should only be called once; # Accumulate the validated output and call at the end - if not guard_span.is_recording(): - # Assuming you have a tracer instance - tracer = get_tracer(__name__) - # Create a new span and link it to the previous span - with tracer.start_as_current_span( - "new_guard_span", # type: ignore - links=[Link(guard_span.get_span_context())], - ) as new_span: - guard_span = new_span - add_guard_attributes(guard_span, history, res) - add_user_attributes(guard_span) + # if not guard_span.is_recording(): + # # Assuming you have a tracer instance + # tracer = get_tracer(__name__) + # # Create a new span and link it to the previous span + # with tracer.start_as_current_span( + # "new_guard_span", # type: ignore + # links=[Link(guard_span.get_span_context())], + # ) as new_span: + # guard_span = new_span + # add_guard_attributes(guard_span, history, res) + # add_user_attributes(guard_span) yield res except StopIteration: next_exists = False @@ -209,18 +209,18 @@ async def trace_async_stream_guard( while next_exists: try: res = await anext(result) # type: ignore - if not guard_span.is_recording(): + # if not guard_span.is_recording(): # Assuming you have a tracer instance - tracer = get_tracer(__name__) - # Create a new span and link it to the previous span - with tracer.start_as_current_span( - "new_guard_span", # type: ignore - links=[Link(guard_span.get_span_context())], - ) as new_span: - guard_span = new_span - - add_guard_attributes(guard_span, history, res) - add_user_attributes(guard_span) + # tracer = get_tracer(__name__) + # # Create a new span and link it to the previous span + # with tracer.start_as_current_span( + # "new_guard_span", # type: ignore + # links=[Link(guard_span.get_span_context())], + # ) as new_span: + # guard_span = new_span + + # add_guard_attributes(guard_span, history, res) + # add_user_attributes(guard_span) yield res except StopIteration: next_exists = False diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index 3b626c7ce..6f1e21754 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -169,6 +169,8 @@ def run_validator( # requires at least 2 validators def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]: + if len(new_values) == 0: + return original current = new_values.pop() while len(new_values) > 0: nextval = new_values.pop() From 3c50ad765c90d552fed54ef8b2ccf829cd653d72 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 8 Oct 2024 13:52:07 -0700 Subject: [PATCH 08/13] cleanup --- guardrails/async_guard.py | 1 + guardrails/classes/llm/llm_response.py | 6 +++++- guardrails/guard.py | 20 +++++++++++--------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 7dabf716a..840893f72 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -574,6 +574,7 @@ async def _stream_server_call( validated_output=validated_output, validation_passed=(validation_output.validation_passed is True), ) + # TODO re-enable this once we have a way to get history from a multi-node server # if validation_output: # guard_history = self._api_client.get_history( # self.name, validation_output.call_id diff --git a/guardrails/classes/llm/llm_response.py b/guardrails/classes/llm/llm_response.py index 2a0467b64..afeb74020 100644 --- a/guardrails/classes/llm/llm_response.py +++ b/guardrails/classes/llm/llm_response.py @@ -47,7 +47,11 @@ def to_interface(self) -> ILLMResponse: stream_output = [str(so) for so in copy_2] async_stream_output = None - if self.async_stream_output: + # dont do this again if already aiter-able were updating ourselves here so in memory + # this can cause issues + if self.async_stream_output and not hasattr( + self.async_stream_output, "__aiter__" + ): # tee doesn't work with async iterators # This may be destructive async_stream_output = [] diff --git a/guardrails/guard.py b/guardrails/guard.py index 0131fbf77..267f7a6a2 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1215,6 +1215,7 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O error="The response from the server was empty!", ) + # TODO renable this when we have history support in multi-node server environments # guard_history = self._api_client.get_history( # self.name, validation_output.call_id # ) @@ -1281,14 +1282,15 @@ def _stream_server_call( validated_output=validated_output, validation_passed=(validation_output.validation_passed is True), ) - if validation_output: - guard_history = self._api_client.get_history( - self.name, validation_output.call_id - ) - # TODO renable this. doesnt work in a multiple server environment - # self.history.extend( - # [Call.from_interface(call) for call in guard_history] - # ) + + # TODO reenable this when sever supports multi-node history + # if validation_output: + # guard_history = self._api_client.get_history( + # self.name, validation_output.call_id + # ) + # self.history.extend( + # [Call.from_interface(call) for call in guard_history] + # ) else: raise ValueError("Guard does not have an api client!") @@ -1370,7 +1372,7 @@ def to_dict(self) -> Dict[str, Any]: description=self.description, validators=self.validators, output_schema=self.output_schema, - history=[], # type: ignore + history=[c.to_interface() for c in self.history], # type: ignore ) return i_guard.to_dict() From bdbb82cd0c76f5812157c778e35fe8ec2241742f Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 9 Oct 2024 09:45:31 -0700 Subject: [PATCH 09/13] cleanup --- guardrails/llm_providers.py | 2 ++ guardrails/telemetry/guard_tracing.py | 34 +++++++++++++-------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index 91ae567dc..8b22cec8e 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -498,6 +498,7 @@ def _invoke_llm( ), ) + # these are gr only and should not be getting passed to llms kwargs.pop("reask_prompt", None) kwargs.pop("reask_instructions", None) @@ -1091,6 +1092,7 @@ async def invoke_llm( ), ) + # these are gr only and should not be getting passed to llms kwargs.pop("reask_prompt", None) kwargs.pop("reask_instructions", None) diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 23028ba41..2fafcaab1 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -11,7 +11,7 @@ ) from opentelemetry import context, trace -from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer +from opentelemetry.trace import StatusCode, Tracer, Span from guardrails.settings import settings from guardrails.classes.generic.stack import Stack @@ -188,11 +188,11 @@ def trace_guard_execution( if isinstance(result, Iterator) and not isinstance( result, ValidationOutcome ): - for res in trace_stream_guard(guard_span, result, history): - yield res - else: - add_guard_attributes(guard_span, history, result) - add_user_attributes(guard_span) + return trace_stream_guard(guard_span, result, history) + + # add_guard_attributes(guard_span, history, result) + # add_user_attributes(guard_span) + return result except Exception as e: guard_span.set_status(status=StatusCode.ERROR, description=str(e)) raise e @@ -210,17 +210,17 @@ async def trace_async_stream_guard( try: res = await anext(result) # type: ignore # if not guard_span.is_recording(): - # Assuming you have a tracer instance - # tracer = get_tracer(__name__) - # # Create a new span and link it to the previous span - # with tracer.start_as_current_span( - # "new_guard_span", # type: ignore - # links=[Link(guard_span.get_span_context())], - # ) as new_span: - # guard_span = new_span - - # add_guard_attributes(guard_span, history, res) - # add_user_attributes(guard_span) + # Assuming you have a tracer instance + # tracer = get_tracer(__name__) + # # Create a new span and link it to the previous span + # with tracer.start_as_current_span( + # "new_guard_span", # type: ignore + # links=[Link(guard_span.get_span_context())], + # ) as new_span: + # guard_span = new_span + + # add_guard_attributes(guard_span, history, res) + # add_user_attributes(guard_span) yield res except StopIteration: next_exists = False From 50c9b5eeba9a090a5b221a72be90c0cab5c50f45 Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 9 Oct 2024 11:02:15 -0700 Subject: [PATCH 10/13] renable some tracing cleanup --- guardrails/telemetry/guard_tracing.py | 52 +++++++++++++-------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 2fafcaab1..8a65b660f 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -11,7 +11,7 @@ ) from opentelemetry import context, trace -from opentelemetry.trace import StatusCode, Tracer, Span +from opentelemetry.trace import StatusCode, Tracer, Span, Link, get_tracer from guardrails.settings import settings from guardrails.classes.generic.stack import Stack @@ -145,18 +145,18 @@ def trace_stream_guard( res = next(result) # type: ignore # FIXME: This should only be called once; # Accumulate the validated output and call at the end - # if not guard_span.is_recording(): - # # Assuming you have a tracer instance - # tracer = get_tracer(__name__) - # # Create a new span and link it to the previous span - # with tracer.start_as_current_span( - # "new_guard_span", # type: ignore - # links=[Link(guard_span.get_span_context())], - # ) as new_span: - # guard_span = new_span - # add_guard_attributes(guard_span, history, res) - # add_user_attributes(guard_span) - yield res + if not guard_span.is_recording(): + # Assuming you have a tracer instance + tracer = get_tracer(__name__) + # Create a new span and link it to the previous span + with tracer.start_as_current_span( + "stream_guard_span", # type: ignore + links=[Link(guard_span.get_span_context())], + ) as new_span: + guard_span = new_span + add_guard_attributes(guard_span, history, res) + add_user_attributes(guard_span) + yield res except StopIteration: next_exists = False @@ -209,19 +209,19 @@ async def trace_async_stream_guard( while next_exists: try: res = await anext(result) # type: ignore - # if not guard_span.is_recording(): - # Assuming you have a tracer instance - # tracer = get_tracer(__name__) - # # Create a new span and link it to the previous span - # with tracer.start_as_current_span( - # "new_guard_span", # type: ignore - # links=[Link(guard_span.get_span_context())], - # ) as new_span: - # guard_span = new_span - - # add_guard_attributes(guard_span, history, res) - # add_user_attributes(guard_span) - yield res + if not guard_span.is_recording(): + # Assuming you have a tracer instance + tracer = get_tracer(__name__) + # Create a new span and link it to the previous span + with tracer.start_as_current_span( + "async_stream_span", # type: ignore + links=[Link(guard_span.get_span_context())], + ) as new_span: + guard_span = new_span + + add_guard_attributes(guard_span, history, res) + add_user_attributes(guard_span) + yield res except StopIteration: next_exists = False except StopAsyncIteration: From 46e1f44328925815022e59de8d30b51538f96491 Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 9 Oct 2024 11:08:56 -0700 Subject: [PATCH 11/13] renable last of tracing --- guardrails/telemetry/guard_tracing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 8a65b660f..1eaf70338 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -190,8 +190,8 @@ def trace_guard_execution( ): return trace_stream_guard(guard_span, result, history) - # add_guard_attributes(guard_span, history, result) - # add_user_attributes(guard_span) + add_guard_attributes(guard_span, history, result) + add_user_attributes(guard_span) return result except Exception as e: guard_span.set_status(status=StatusCode.ERROR, description=str(e)) From 5ad9dbdfe2feff0687efd502768ac210e53c47a2 Mon Sep 17 00:00:00 2001 From: David Tam Date: Wed, 9 Oct 2024 15:19:08 -0700 Subject: [PATCH 12/13] get branch green --- guardrails/api_client.py | 2 +- guardrails/async_guard.py | 3 ++- guardrails/classes/llm/llm_response.py | 3 ++- guardrails/guard.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/guardrails/api_client.py b/guardrails/api_client.py index 967e441b0..a3a71fd4c 100644 --- a/guardrails/api_client.py +++ b/guardrails/api_client.py @@ -104,7 +104,7 @@ def stream_validate( ) if line: json_output = json.loads(line) - if(json_output.get("error")): + if json_output.get("error"): raise Exception(json_output.get("error").get("message")) yield IValidationOutcome.from_dict(json_output) diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index 840893f72..a8e9c5fd8 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -574,7 +574,8 @@ async def _stream_server_call( validated_output=validated_output, validation_passed=(validation_output.validation_passed is True), ) - # TODO re-enable this once we have a way to get history from a multi-node server + # TODO re-enable this once we have a way to get history + # from a multi-node server # if validation_output: # guard_history = self._api_client.get_history( # self.name, validation_output.call_id diff --git a/guardrails/classes/llm/llm_response.py b/guardrails/classes/llm/llm_response.py index afeb74020..cf32a9cc4 100644 --- a/guardrails/classes/llm/llm_response.py +++ b/guardrails/classes/llm/llm_response.py @@ -47,7 +47,8 @@ def to_interface(self) -> ILLMResponse: stream_output = [str(so) for so in copy_2] async_stream_output = None - # dont do this again if already aiter-able were updating ourselves here so in memory + # dont do this again if already aiter-able were updating + # ourselves here so in memory # this can cause issues if self.async_stream_output and not hasattr( self.async_stream_output, "__aiter__" diff --git a/guardrails/guard.py b/guardrails/guard.py index 267f7a6a2..67378c50b 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1215,7 +1215,8 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O error="The response from the server was empty!", ) - # TODO renable this when we have history support in multi-node server environments + # TODO renable this when we have history support in + # multi-node server environments # guard_history = self._api_client.get_history( # self.name, validation_output.call_id # ) From f3f39d1fd8a771799b3c4c8549a0f9c09fdfa901 Mon Sep 17 00:00:00 2001 From: David Tam Date: Thu, 10 Oct 2024 09:50:29 -0700 Subject: [PATCH 13/13] tests and enable async guards by default from create template --- guardrails/cli/create.py | 2 +- .../cli/hub/template_config.py.template | 2 +- guardrails/guard.py | 2 +- server_ci/config.py | 4 +- server_ci/tests/test_server.py | 55 ++++++++++++++++++- 5 files changed, 59 insertions(+), 6 deletions(-) diff --git a/guardrails/cli/create.py b/guardrails/cli/create.py index 4668236ba..0048f44f2 100644 --- a/guardrails/cli/create.py +++ b/guardrails/cli/create.py @@ -118,7 +118,7 @@ def generate_template_config( guard_instantiations = [] for i, guard in enumerate(template["guards"]): - guard_instantiations.append(f"guard{i} = Guard.from_dict(guards[{i}])") + guard_instantiations.append(f"guard{i} = AsyncGuard.from_dict(guards[{i}])") guard_instantiations = "\n".join(guard_instantiations) # Interpolate variables output_content = template_content.format( diff --git a/guardrails/cli/hub/template_config.py.template b/guardrails/cli/hub/template_config.py.template index 5d5683bdf..b61ffb2f2 100644 --- a/guardrails/cli/hub/template_config.py.template +++ b/guardrails/cli/hub/template_config.py.template @@ -1,6 +1,6 @@ import json import os -from guardrails import Guard +from guardrails import AsyncGuard, Guard from guardrails.hub import {VALIDATOR_IMPORTS} try: diff --git a/guardrails/guard.py b/guardrails/guard.py index 67378c50b..200b83e3b 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1215,7 +1215,7 @@ def _single_server_call(self, *, payload: Dict[str, Any]) -> ValidationOutcome[O error="The response from the server was empty!", ) - # TODO renable this when we have history support in + # TODO reenable this when we have history support in # multi-node server environments # guard_history = self._api_client.get_history( # self.name, validation_output.call_id diff --git a/server_ci/config.py b/server_ci/config.py index b84d637ea..d5a093ca7 100644 --- a/server_ci/config.py +++ b/server_ci/config.py @@ -1,6 +1,6 @@ import json import os -from guardrails import Guard +from guardrails import AsyncGuard try: file_path = os.path.join(os.getcwd(), "guard-template.json") @@ -11,4 +11,4 @@ SystemExit(1) # instantiate guards -guard0 = Guard.from_dict(guards[0]) +guard0 = AsyncGuard.from_dict(guards[0]) diff --git a/server_ci/tests/test_server.py b/server_ci/tests/test_server.py index 9d51fde0f..041531d78 100644 --- a/server_ci/tests/test_server.py +++ b/server_ci/tests/test_server.py @@ -1,7 +1,7 @@ import openai import os import pytest -from guardrails import Guard, settings +from guardrails import AsyncGuard, Guard, settings # OpenAI compatible Guardrails API Guard openai.base_url = "http://127.0.0.1:8000/guards/test-guard/openai/v1/" @@ -32,6 +32,59 @@ def test_guard_validation(mock_llm_output, validation_output, validation_passed, assert validation_outcome.validated_output == validation_output +@pytest.mark.asyncio +async def test_async_guard_validation(): + settings.use_server = True + guard = AsyncGuard(name="test-guard") + + validation_outcome = await guard( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], + temperature=0.0, + ) + + assert validation_outcome.validation_passed is True # noqa: E712 + assert validation_outcome.validated_output == "Citrus fruit," + + +@pytest.mark.asyncio +async def test_async_streaming_guard_validation(): + settings.use_server = True + guard = AsyncGuard(name="test-guard") + + async_iterator = await guard( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], + stream=True, + temperature=0.0, + ) + + full_output = "" + async for validation_chunk in async_iterator: + full_output += validation_chunk.validated_output + + assert full_output == "Citrus fruit,Citrus fruit," + + +@pytest.mark.asyncio +async def test_sync_streaming_guard_validation(): + settings.use_server = True + guard = Guard(name="test-guard") + + iterator = guard( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Tell me about Oranges in 5 words"}], + stream=True, + temperature=0.0, + ) + + full_output = "" + for validation_chunk in iterator: + full_output += validation_chunk.validated_output + + assert full_output == "Citrus fruit,Citrus fruit," + + @pytest.mark.parametrize( "message_content, output, validation_passed, error", [