Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion chatsky/core/ctx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from asyncio import Event
from json import loads
from time import time_ns
from typing import Any, Optional, Dict, TYPE_CHECKING
from typing import Any, List, Optional, Dict, Literal, TYPE_CHECKING, Tuple

from pydantic import BaseModel, Field, PrivateAttr, TypeAdapter, field_serializer, field_validator

Expand Down Expand Up @@ -40,6 +40,18 @@ class ServiceState(BaseModel, arbitrary_types_allowed=True):
Cleared at the end of every turn.
"""

STAGES = ["PRE_SERVICE", "PRE_TRANSITION", "CONDITION", "PRIORITY", "DESTINATION", "PRE_RESPONSE", "RESPONSE", "POST_SERVICE"]

class ExceptionInfo(BaseModel, arbitrary_types_allowed=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove exception info model, replace with Dict[Stages, List[Exception]], and defaultdict(list).

PRE_SERVICE: List[Exception] = Field(default_factory=list)
PRE_TRANSITION: List[Exception] = Field(default_factory=list)
CONDITION: List[Exception] = Field(default_factory=list)
PRIORITY: List[Exception] = Field(default_factory=list)
DESTINATION: List[Exception] = Field(default_factory=list)
TRANSITION: List[Exception] = Field(default_factory=list)
PRE_RESPONSE: List[Exception] = Field(default_factory=list)
RESPONSE: List[Exception] = Field(default_factory=list)
POST_SERVICE: List[Exception] = Field(default_factory=list)

class FrameworkData(BaseModel, arbitrary_types_allowed=True):
"""
Expand Down Expand Up @@ -74,6 +86,19 @@ class FrameworkData(BaseModel, arbitrary_types_allowed=True):
- no transition has been made during this turn yet (e.g. the turn is in the pre-transition step);
- no valid transition has been found (i.e. transitioned to fallback node).
"""
current_stage: Optional[Literal[*STAGES]] = Field(default=None, exclude=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can try to do Literal[*ExceptionInfo.model_fields.keys()] to avoid duplication in ExceptionInfo and STAGES.

"Stores current processing stage"
exception_info: Optional[ExceptionInfo] = Field(default=None, exclude=True)
"Stores exceptions raised at different stages"

def get_exception(self, stage: Optional[Literal[*STAGES]]) -> Optional[Tuple[str, ExceptionInfo]]:
if stage is None:
for stage in reversed(STAGES):
exception_list = getattr(self.exception_info, stage, [])
if exception_list:
return stage, exception_list[-1]
else:
return getattr(self.exception_info, stage, [])


class ContextMainInfo(BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions chatsky/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,12 @@ async def _run_pipeline(
ctx.current_turn_id = ctx.current_turn_id + 1

ctx.requests[ctx.current_turn_id] = request
ctx.framework_data.current_stage = "PRE_SERVICE"

await self.services_pipeline(ctx)

ctx.framework_data.service_states.clear()
ctx.framework_data.current_stage = None
ctx.framework_data.pipeline = None

await ctx.store()
Expand Down
1 change: 1 addition & 0 deletions chatsky/core/script_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def wrapped_call(self, ctx: Context, *, info: str = ""):
return result
except Exception as exc:
logger.error(f"An exception occurred in {self.__class__.__name__}. {info}", exc_info=exc)
ctx.framework_data.exception_info[ctx.framework_data.current_stage].append(exc)
return exc

async def __call__(self, ctx: Context):
Expand Down
4 changes: 4 additions & 0 deletions chatsky/core/service/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def run_component(self, ctx: Context) -> None:
ctx.framework_data.current_node = ctx.pipeline.script.get_inherited_node(ctx.last_label)

logger.debug("Running pre_transition")
ctx.framework_data.current_stage = "PRE_TRANSITION"
await self._run_processing(ctx.current_node.pre_transition, ctx)

logger.debug("Running transitions")
Expand All @@ -79,6 +80,7 @@ async def run_component(self, ctx: Context) -> None:
ctx.framework_data.current_node = ctx.pipeline.script.get_inherited_node(next_label)

logger.debug("Running pre_response")
ctx.framework_data.current_stage = "PRE_RESPONSE"
await self._run_processing(ctx.current_node.pre_response, ctx)

node_response = ctx.current_node.response
Expand All @@ -87,6 +89,7 @@ async def run_component(self, ctx: Context) -> None:
if isinstance(response_result, Message):
response = response_result
logger.debug(f"Produced response {response}.")
ctx.framework_data.current_stage = "RESPONSE"
else:
logger.debug("Response was not produced.")
else:
Expand All @@ -95,6 +98,7 @@ async def run_component(self, ctx: Context) -> None:
logger.exception("Exception occurred during response processing.", exc_info=exc)

ctx.responses[ctx.current_turn_id] = response
ctx.framework_data.current_stage = "POST_SERVICE"

@staticmethod
async def _run_processing_parallel(processing: Dict[str, BaseProcessing], ctx: Context) -> None:
Expand Down
4 changes: 3 additions & 1 deletion chatsky/core/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ async def get_next_label(
2. The transition which lead to the next node or ``None`` if no transition is left by the end of the process.
"""
filtered_transitions: List[Transition] = transitions.copy()
ctx.framework_data.current_stage = "CONDITION"
condition_results = await asyncio.gather(*[transition.cnd.wrapped_call(ctx) for transition in filtered_transitions])

filtered_transitions = [
transition for transition, condition in zip(filtered_transitions, condition_results) if condition is True
]

ctx.framework_data.current_stage = "PRIORITY"
priority_results = await asyncio.gather(
*[transition.priority.wrapped_call(ctx) for transition in filtered_transitions]
)
Expand All @@ -94,7 +96,7 @@ async def get_next_label(
logger.debug(f"Possible transitions: {transitions_with_priorities!r}")

transitions_with_priorities = sorted(transitions_with_priorities, key=lambda x: x[1], reverse=True)

ctx.framework_data.current_stage = "DESTINATION"
destination_results: List[Union[AbsoluteNodeLabel, Exception]] = await asyncio.gather(
*[transition.dst.wrapped_call(ctx) for transition, _ in transitions_with_priorities]
)
Expand Down
Loading