Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions chatsky/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def validate_fallback_label(self):
return self

async def _run_pipeline(
self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None
self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
) -> Context:
"""
Method that should be invoked on user input.
Expand Down Expand Up @@ -292,9 +292,7 @@ def run(self):
logger.info("Pipeline is accepting requests.")
asyncio.run(self.messenger_interface.connect(self._run_pipeline))

def __call__(
self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None
) -> Context:
def __call__(self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None) -> Context:
"""
Method that executes pipeline once.
Basically, it is a shortcut for :py:meth:`_run_pipeline`.
Expand Down
5 changes: 2 additions & 3 deletions chatsky/core/service/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations
from enum import unique, Enum
from typing import Callable, Union, Awaitable, Optional, Any, Protocol, Hashable, TYPE_CHECKING
from typing import Callable, Union, Awaitable, Optional, Any, Protocol, TYPE_CHECKING
from typing_extensions import TypeAlias
from pydantic import BaseModel

Expand All @@ -25,13 +25,12 @@ class PipelineRunnerFunction(Protocol):
"""

def __call__(
self, message: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
self, message: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
) -> Awaitable[Context]:
"""
:param message: User request for pipeline to process.
:param ctx_id:
ID of the context that the new request belongs to.
Optional, None by default.
If set to `None`, a new context will be created with `message` being the first request.
:param update_ctx_misc:
Dictionary to be passed as an argument to `ctx.misc.update`.
Expand Down
10 changes: 4 additions & 6 deletions chatsky/messengers/common/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
from pathlib import Path
from tempfile import gettempdir
from typing import Optional, Any, List, Tuple, Hashable, TYPE_CHECKING, Type
from typing import Optional, Any, List, Tuple, TYPE_CHECKING, Type

if TYPE_CHECKING:
from chatsky.core import Context
Expand Down Expand Up @@ -99,7 +99,7 @@ class PollingMessengerInterface(MessengerInterface):
"""

@abc.abstractmethod
def _request(self) -> List[Tuple[Message, Hashable]]:
def _request(self) -> List[Tuple[Message, str]]:
"""
Method used for sending users request for their input.

Expand Down Expand Up @@ -181,17 +181,15 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction):
self._pipeline_runner = pipeline_runner

async def on_request_async(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
) -> Context:
"""
Method that should be invoked on user input.
This method has the same signature as :py:class:`~chatsky.core.service.types.PipelineRunnerFunction`.
"""
return await self._pipeline_runner(request, ctx_id, update_ctx_misc)

def on_request(
self, request: Any, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
) -> Context:
def on_request(self, request: Any, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None) -> Context:
"""
Method that should be invoked on user input.
This method has the same signature as :py:class:`~chatsky.core.service.types.PipelineRunnerFunction`.
Expand Down
4 changes: 2 additions & 2 deletions chatsky/messengers/console.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Hashable, List, Optional, TextIO, Tuple
from typing import Any, List, Optional, TextIO, Tuple
from uuid import uuid4
from chatsky.messengers.common.interface import PollingMessengerInterface
from chatsky.core.service.types import PipelineRunnerFunction
Expand All @@ -23,7 +23,7 @@ def __init__(
out_descriptor: Optional[TextIO] = None,
):
super().__init__()
self._ctx_id: Optional[Hashable] = None
self._ctx_id: Optional[str] = None
self._intro: Optional[str] = intro
self._prompt_request: str = prompt_request
self._prompt_response: str = prompt_response
Expand Down
14 changes: 7 additions & 7 deletions tests/core/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_normal_execution(self):
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.labels._items == {
0: AbsoluteNodeLabel(flow_name="flow", node_name="node1"),
Expand All @@ -49,7 +49,7 @@ async def test_fallback_node(self):
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.labels._items == {
0: AbsoluteNodeLabel(flow_name="flow", node_name="node"),
Expand Down Expand Up @@ -85,7 +85,7 @@ async def test_default_priority(self, default_priority, result):
default_priority=default_priority,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.last_label.node_name == result

Expand All @@ -105,7 +105,7 @@ async def call(self, ctx: Context) -> None:
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.last_label.node_name == "fallback"
assert log_list[0].msg == "Exception occurred during transition processing."
Expand All @@ -122,7 +122,7 @@ async def test_empty_response(self, log_event_catcher):
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.responses == {1: Message()}
assert log_list[-1].msg == "Node has empty response."
Expand All @@ -142,7 +142,7 @@ async def call(self, ctx: Context) -> MessageInitTypes:
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.responses == {1: Message()}
assert log_list[-1].msg == "Response was not produced."
Expand All @@ -162,7 +162,7 @@ async def call(self, ctx: Context) -> None:
parallelize_processing=True,
)

ctx = await pipeline._run_pipeline(Message())
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")

assert ctx.responses == {1: Message()}
assert log_list[0].msg == "Exception occurred during response processing."
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def call(self, ctx: Context) -> MessageInitTypes:
return ctx.pipeline.start_label.node_name

pipeline = Pipeline(script={"flow": {"node": {RESPONSE: MyResponse()}}}, start_label=("flow", "node"))
ctx = await pipeline._run_pipeline(Message(text=""))
ctx = await pipeline._run_pipeline(Message(text=""), ctx_id="0")

assert ctx.last_response == Message(text="node")

Expand All @@ -145,7 +145,7 @@ async def call(self, ctx: Context) -> None:
script={"flow": {"node": {PRE_RESPONSE: {"": MyProcessing()}, PRE_TRANSITION: {"": MyProcessing()}}}},
start_label=("flow", "node"),
)
ctx = await pipeline._run_pipeline(Message(text=""))
ctx = await pipeline._run_pipeline(Message(text=""), ctx_id="0")
assert len(log) == 2

ctx.framework_data.current_node = None
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from random import randint
from shutil import rmtree
from typing import Hashable, Optional, TextIO
from typing import Optional, TextIO
from urllib.request import urlopen

import pytest
Expand Down Expand Up @@ -48,7 +48,7 @@ class ChatskyCLIMessengerInterface(CLIMessengerInterface, MessengerInterfaceWith

def __init__(self, attachments_directory: Optional[Path] = None):
MessengerInterfaceWithAttachments.__init__(self, attachments_directory)
self._ctx_id: Optional[Hashable] = None
self._ctx_id: Optional[str] = None
self._intro: Optional[str] = None
self._prompt_request: str = "request: "
self._prompt_response: str = "response: "
Expand Down
4 changes: 2 additions & 2 deletions tests/messengers/telegram/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextlib import contextmanager
from importlib import import_module
from hashlib import sha256
from typing import Any, Dict, Hashable, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from pydantic import BaseModel
from telegram import InputFile, InputMedia, Update
Expand Down Expand Up @@ -90,7 +90,7 @@ def _wrap_pipeline_runner(self) -> Iterator[None]:
original_pipeline_runner = self.interface._pipeline_runner

async def wrapped_pipeline_runner(
message: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
message: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
) -> Context:
self.latest_ctx = await original_pipeline_runner(message, ctx_id, update_ctx_misc)
return self.latest_ctx
Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def call(self, ctx: Context):
pre_services=[MyProcessing(wait=0.02, text="A")],
post_services=[MyProcessing(wait=0, text="C")],
)
await pipeline._run_pipeline(Message(""))
await pipeline._run_pipeline(Message(""), ctx_id="0")
assert logs == ["A", "B", "C"]


Expand Down
2 changes: 1 addition & 1 deletion tutorials/messengers/web_api_interface/2_websocket_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
await websocket.send_text(f"User: {data}")
request = Message(data)
context = await messenger_interface.on_request_async(
request, client_id
request, str(client_id)
)
response = context.last_response.text
if response is not None:
Expand Down
Loading