From 4f8d43efcac145f559151fc7867c10343dbe489d Mon Sep 17 00:00:00 2001 From: Will Gao Date: Mon, 24 Jun 2024 16:38:17 -0400 Subject: [PATCH 1/2] extract bedrock trace and surface error/exception --- libs/aws/langchain_aws/llms/bedrock.py | 48 +++++++++++++++++++------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 24416eac..33aaf257 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -35,7 +35,7 @@ ) AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace" -GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment" +GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAction" HUMAN_PROMPT = "\n\nHuman:" ASSISTANT_PROMPT = "\n\nAssistant:" ALTERNATION_ERROR = ( @@ -298,7 +298,7 @@ def prepare_output_stream( response: Any, stop: Optional[List[str]] = None, messages_api: bool = False, - ) -> Iterator[GenerationChunk]: + ) -> Union[Iterator[GenerationChunk], Iterator[Dict]]: stream = response.get("body") if not stream: @@ -333,7 +333,7 @@ def prepare_output_stream( return elif messages_api and (chunk_obj.get("type") == "message_stop"): - return + yield chunk_obj generation_chunk = _stream_response_to_generation_chunk( chunk_obj, @@ -396,6 +396,13 @@ async def aprepare_output_stream( continue +class BedrockGuardrailError(Exception): + """Raised when a guardrail is triggered.""" + + def __init__(self): + super().__init__("Blocked by Bedrock Guardrails") + + class BedrockBase(BaseLanguageModel, ABC): """Base class for Bedrock models.""" @@ -693,6 +700,9 @@ def _prepare_input_and_invoke( ), **services_trace, ) + raise Exception( + f"Error raised by bedrock service: {services_trace.get('reason')}" + ) return text, tool_calls, llm_output @@ -722,7 +732,7 @@ def _get_bedrock_services_signal(self, body: dict) -> dict: } def _is_guardrails_intervention(self, body: dict) -> bool: - return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED" + return body.get(GUARDRAILS_BODY_KEY) == "INTERVENED" def _prepare_input_and_invoke_stream( self, @@ -795,17 +805,29 @@ def _prepare_input_and_invoke_stream( raise ValueError(f"Error raised by bedrock service: {e}") for chunk in LLMInputOutputAdapter.prepare_output_stream( - provider, - response, - stop, - True if messages else False, + provider, response, stop, True if messages else False ): - yield chunk - # verify and raise callback error if any middleware intervened - self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] + if isinstance(chunk, dict): + services_trace = self._get_bedrock_services_signal(chunk) + print(f"services_trace: {services_trace}") + + if run_manager is not None and services_trace.get("signal"): + run_manager.on_llm_error( + Exception( + f"Error raised by bedrock service: {services_trace.get('reason')}" + ), + **services_trace, + ) + raise Exception( + f"Error raised by bedrock service: {services_trace.get('reason')}" + ) + else: + yield chunk + # verify and raise callback error if any middleware intervened + self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] - if run_manager is not None: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + if run_manager is not None: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) async def _aprepare_input_and_invoke_stream( self, From 82a7c7ce8771681da5ba5c6a8d4d6e6fdbfd55f6 Mon Sep 17 00:00:00 2001 From: Will Gao Date: Tue, 25 Jun 2024 11:51:43 -0400 Subject: [PATCH 2/2] let callbacks handle exception --- libs/aws/langchain_aws/llms/bedrock.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 33aaf257..07ce1a65 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -396,13 +396,6 @@ async def aprepare_output_stream( continue -class BedrockGuardrailError(Exception): - """Raised when a guardrail is triggered.""" - - def __init__(self): - super().__init__("Blocked by Bedrock Guardrails") - - class BedrockBase(BaseLanguageModel, ABC): """Base class for Bedrock models.""" @@ -700,9 +693,6 @@ def _prepare_input_and_invoke( ), **services_trace, ) - raise Exception( - f"Error raised by bedrock service: {services_trace.get('reason')}" - ) return text, tool_calls, llm_output @@ -818,9 +808,6 @@ def _prepare_input_and_invoke_stream( ), **services_trace, ) - raise Exception( - f"Error raised by bedrock service: {services_trace.get('reason')}" - ) else: yield chunk # verify and raise callback error if any middleware intervened