diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 24416eac..07ce1a65 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -35,7 +35,7 @@ ) AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace" -GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment" +GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAction" HUMAN_PROMPT = "\n\nHuman:" ASSISTANT_PROMPT = "\n\nAssistant:" ALTERNATION_ERROR = ( @@ -298,7 +298,7 @@ def prepare_output_stream( response: Any, stop: Optional[List[str]] = None, messages_api: bool = False, - ) -> Iterator[GenerationChunk]: + ) -> Union[Iterator[GenerationChunk], Iterator[Dict]]: stream = response.get("body") if not stream: @@ -333,7 +333,7 @@ def prepare_output_stream( return elif messages_api and (chunk_obj.get("type") == "message_stop"): - return + yield chunk_obj generation_chunk = _stream_response_to_generation_chunk( chunk_obj, @@ -722,7 +722,7 @@ def _get_bedrock_services_signal(self, body: dict) -> dict: } def _is_guardrails_intervention(self, body: dict) -> bool: - return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED" + return body.get(GUARDRAILS_BODY_KEY) == "INTERVENED" def _prepare_input_and_invoke_stream( self, @@ -795,17 +795,26 @@ def _prepare_input_and_invoke_stream( raise ValueError(f"Error raised by bedrock service: {e}") for chunk in LLMInputOutputAdapter.prepare_output_stream( - provider, - response, - stop, - True if messages else False, + provider, response, stop, True if messages else False ): - yield chunk - # verify and raise callback error if any middleware intervened - self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] + if isinstance(chunk, dict): + services_trace = self._get_bedrock_services_signal(chunk) + print(f"services_trace: {services_trace}") + + if run_manager is not None and services_trace.get("signal"): + run_manager.on_llm_error( + Exception( + f"Error raised by bedrock service: {services_trace.get('reason')}" + ), + **services_trace, + ) + else: + yield chunk + # verify and raise callback error if any middleware intervened + self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type] - if run_manager is not None: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + if run_manager is not None: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) async def _aprepare_input_and_invoke_stream( self,