diff --git a/chatsky/core/ctx_utils.py b/chatsky/core/ctx_utils.py index a00738e134..8ce962cf7d 100644 --- a/chatsky/core/ctx_utils.py +++ b/chatsky/core/ctx_utils.py @@ -64,6 +64,11 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True): "Enables complex stats collection across multiple turns." slot_manager: SlotManager = Field(default_factory=SlotManager) "Stores extracted slots." + response_exception: Optional[str] = Field(default=None, exclude=True) + """ + Stores exception messages raised from response functions wrapped in + :py:class:`~chatsky.processing.standard.AddFallbackResponses`. + """ class ContextMainInfo(BaseModel): diff --git a/chatsky/processing/__init__.py b/chatsky/processing/__init__.py index fcd984a05d..888bea2803 100644 --- a/chatsky/processing/__init__.py +++ b/chatsky/processing/__init__.py @@ -1,2 +1,2 @@ -from .standard import ModifyResponse +from .standard import ModifyResponse, AddFallbackResponses from .slots import Extract, Unset, UnsetAll, FillTemplate diff --git a/chatsky/processing/standard.py b/chatsky/processing/standard.py index 8b3fa2aab3..1b9826fb58 100644 --- a/chatsky/processing/standard.py +++ b/chatsky/processing/standard.py @@ -7,8 +7,13 @@ """ import abc +import logging +from typing import Literal, Union, Dict +from pydantic import field_validator -from chatsky.core import BaseProcessing, BaseResponse, Context, MessageInitTypes +from chatsky.core import BaseProcessing, BaseResponse, Context, MessageInitTypes, AnyResponse + +logger = logging.getLogger(__name__) class ModifyResponse(BaseProcessing, abc.ABC): @@ -24,6 +29,8 @@ async def modified_response(self, original_response: BaseResponse, ctx: Context) :param original_response: Response of the current node when :py:class:`.ModifyResponse` is called. :param ctx: Current context. + + :return: Message to replace original response with. """ raise NotImplementedError @@ -39,3 +46,70 @@ async def call(self, ctx: Context) -> MessageInitTypes: return await processing_object.modified_response(current_response, ctx) ctx.current_node.response = ModifiedResponse() + + +class AddFallbackResponses(ModifyResponse): + """ + ModifyResponse to handlie exceptions dynamically using a user-provided + dictionary of exception-to-response mappings. + When an exception occurs, its string representation is stored in + :py:attr:`ctx.framework_data.response_exception`, and a corresponding + fallback response is used. + + Example: + + .. code-block:: python + + # Usage example + + PRE_RESPONSE: { + "add_fallback_responses": AddFallbackResponses( + exception_responses={ + "OverflowError": "Overflow!", + "ValueError": MyResponse(), + "Else": "Other exception occured", + } + ) + } + + """ + + exception_responses: Dict[Union[str, Literal["Else"]], AnyResponse] + """ + Dictionary mapping exception types to fallback responses. + """ + + @field_validator("exception_responses") + @classmethod + def validate_not_empty(cls, exception_responses: dict) -> dict: + """ + Validate that the `exception_responses` dictionary is not empty. + + :param exception_responses: Dictionary mapping exception types to fallback responses. + :raises ValueError: If the `exception_responses` dictionary is empty. + :return: Not empty dictionary of exception_responses. + """ + if len(exception_responses) == 0: + raise ValueError("Exceptions dict is empty") + return exception_responses + + async def modified_response(self, original_response: BaseResponse, ctx: Context) -> MessageInitTypes: + """ + Catch response errors and process them based on `exception_responses` dictionary. + + :param original_response: The original response of the current node. + :param ctx: The current context. + + :return: Message to replace original response with. + """ + result = await original_response.wrapped_call(ctx) + if isinstance(result, Exception): + exception_response = self.exception_responses.get( + type(result).__name__, self.exception_responses.get("Else") + ) + ctx.framework_data.response_exception = repr(result) + if exception_response is None: + raise result + return await exception_response(ctx) + else: + return result diff --git a/tests/core/test_processing.py b/tests/core/test_processing.py index 0c3e7c509a..ddb568eb7b 100644 --- a/tests/core/test_processing.py +++ b/tests/core/test_processing.py @@ -1,3 +1,5 @@ +import pytest + from chatsky import proc, Context, BaseResponse, MessageInitTypes, Message from chatsky.core.script import Node @@ -22,3 +24,51 @@ async def modified_response(self, original_response: BaseResponse, ctx: Context) assert ctx.current_node.response.__class__.__name__ == "ModifiedResponse" assert await ctx.current_node.response(ctx) == Message(misc={"msg": Message("hi")}) + + +class TestAddFallbackResponses: + """ + A class to group and test the functionality of FallbackResponse. + """ + + class ReturnException(BaseResponse): + async def call(self, ctx: Context): + return ctx.framework_data.response_exception + + class RaiseException(BaseResponse, arbitrary_types_allowed=True): + exception: Exception + + async def call(self, ctx: Context): + raise self.exception + + @pytest.mark.parametrize( + "response_with_exception, expected_response", + [ + (RaiseException(exception=OverflowError()), "Overflow!"), + (RaiseException(exception=KeyError()), "Other exception occured"), + (RaiseException(exception=ValueError("some text")), "ValueError('some text')"), + ], + ) + @pytest.mark.asyncio + async def test_fallback_response(self, response_with_exception, expected_response): + ctx = Context() + ctx.framework_data.current_node = Node() + + exceptions = { + "OverflowError": "Overflow!", + "ValueError": self.ReturnException(), + "Else": "Other exception occured", + } + + fallback_response = proc.AddFallbackResponses(exception_responses=exceptions) + ctx.current_node.response = response_with_exception + await fallback_response(ctx) + assert await ctx.current_node.response(ctx) == Message(text=expected_response) + + async def test_fallback_empty_exceptions(self): + ctx = Context() + ctx.framework_data.current_node = Node() + + exceptions = {} + with pytest.raises(ValueError, match="Exceptions dict is empty"): + proc.AddFallbackResponses(exception_responses=exceptions) diff --git a/tutorials/script/core/7_pre_response_processing.py b/tutorials/script/core/7_pre_response_processing.py index 17f55b0420..e4e73bb355 100644 --- a/tutorials/script/core/7_pre_response_processing.py +++ b/tutorials/script/core/7_pre_response_processing.py @@ -97,10 +97,12 @@ async def modified_response(self, original_response, ctx): except Exception as exc: return str(exc) +Note: the functionality of adding custom responses for exceptions +is available in the core library as the +%mddoclink(api,processing.standard,AddFallbackResponses) class. """ - # %% toy_script = { "root": {