Skip to content

Commit 9c4f32d

Browse files
authored
Update pipeline runner signature (#429)
Remove default value for `ctx_id`; Fix type annotations for `ctx_id`; Convert `ctx_id` to string in websocket tutorial.
1 parent 739bfaa commit 9c4f32d

File tree

10 files changed

+25
-30
lines changed

10 files changed

+25
-30
lines changed

chatsky/core/pipeline.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def validate_fallback_label(self):
233233
return self
234234

235235
async def _run_pipeline(
236-
self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None
236+
self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
237237
) -> Context:
238238
"""
239239
Method that should be invoked on user input.
@@ -292,9 +292,7 @@ def run(self):
292292
logger.info("Pipeline is accepting requests.")
293293
asyncio.run(self.messenger_interface.connect(self._run_pipeline))
294294

295-
def __call__(
296-
self, request: Message, ctx_id: Optional[str] = None, update_ctx_misc: Optional[dict] = None
297-
) -> Context:
295+
def __call__(self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None) -> Context:
298296
"""
299297
Method that executes pipeline once.
300298
Basically, it is a shortcut for :py:meth:`_run_pipeline`.

chatsky/core/service/types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from __future__ import annotations
1111
from enum import unique, Enum
12-
from typing import Callable, Union, Awaitable, Optional, Any, Protocol, Hashable, TYPE_CHECKING
12+
from typing import Callable, Union, Awaitable, Optional, Any, Protocol, TYPE_CHECKING
1313
from typing_extensions import TypeAlias
1414
from pydantic import BaseModel
1515

@@ -25,13 +25,12 @@ class PipelineRunnerFunction(Protocol):
2525
"""
2626

2727
def __call__(
28-
self, message: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
28+
self, message: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
2929
) -> Awaitable[Context]:
3030
"""
3131
:param message: User request for pipeline to process.
3232
:param ctx_id:
3333
ID of the context that the new request belongs to.
34-
Optional, None by default.
3534
If set to `None`, a new context will be created with `message` being the first request.
3635
:param update_ctx_misc:
3736
Dictionary to be passed as an argument to `ctx.misc.update`.

chatsky/messengers/common/interface.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
from pathlib import Path
1313
from tempfile import gettempdir
14-
from typing import Optional, Any, List, Tuple, Hashable, TYPE_CHECKING, Type
14+
from typing import Optional, Any, List, Tuple, TYPE_CHECKING, Type
1515

1616
if TYPE_CHECKING:
1717
from chatsky.core import Context
@@ -99,7 +99,7 @@ class PollingMessengerInterface(MessengerInterface):
9999
"""
100100

101101
@abc.abstractmethod
102-
def _request(self) -> List[Tuple[Message, Hashable]]:
102+
def _request(self) -> List[Tuple[Message, str]]:
103103
"""
104104
Method used for sending users request for their input.
105105
@@ -181,17 +181,15 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction):
181181
self._pipeline_runner = pipeline_runner
182182

183183
async def on_request_async(
184-
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
184+
self, request: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
185185
) -> Context:
186186
"""
187187
Method that should be invoked on user input.
188188
This method has the same signature as :py:class:`~chatsky.core.service.types.PipelineRunnerFunction`.
189189
"""
190190
return await self._pipeline_runner(request, ctx_id, update_ctx_misc)
191191

192-
def on_request(
193-
self, request: Any, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
194-
) -> Context:
192+
def on_request(self, request: Any, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None) -> Context:
195193
"""
196194
Method that should be invoked on user input.
197195
This method has the same signature as :py:class:`~chatsky.core.service.types.PipelineRunnerFunction`.

chatsky/messengers/console.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Hashable, List, Optional, TextIO, Tuple
1+
from typing import Any, List, Optional, TextIO, Tuple
22
from uuid import uuid4
33
from chatsky.messengers.common.interface import PollingMessengerInterface
44
from chatsky.core.service.types import PipelineRunnerFunction
@@ -23,7 +23,7 @@ def __init__(
2323
out_descriptor: Optional[TextIO] = None,
2424
):
2525
super().__init__()
26-
self._ctx_id: Optional[Hashable] = None
26+
self._ctx_id: Optional[str] = None
2727
self._intro: Optional[str] = intro
2828
self._prompt_request: str = prompt_request
2929
self._prompt_response: str = prompt_response

tests/core/test_actor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_normal_execution(self):
3131
parallelize_processing=True,
3232
)
3333

34-
ctx = await pipeline._run_pipeline(Message())
34+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
3535

3636
assert ctx.labels._items == {
3737
0: AbsoluteNodeLabel(flow_name="flow", node_name="node1"),
@@ -49,7 +49,7 @@ async def test_fallback_node(self):
4949
parallelize_processing=True,
5050
)
5151

52-
ctx = await pipeline._run_pipeline(Message())
52+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
5353

5454
assert ctx.labels._items == {
5555
0: AbsoluteNodeLabel(flow_name="flow", node_name="node"),
@@ -85,7 +85,7 @@ async def test_default_priority(self, default_priority, result):
8585
default_priority=default_priority,
8686
)
8787

88-
ctx = await pipeline._run_pipeline(Message())
88+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
8989

9090
assert ctx.last_label.node_name == result
9191

@@ -105,7 +105,7 @@ async def call(self, ctx: Context) -> None:
105105
parallelize_processing=True,
106106
)
107107

108-
ctx = await pipeline._run_pipeline(Message())
108+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
109109

110110
assert ctx.last_label.node_name == "fallback"
111111
assert log_list[0].msg == "Exception occurred during transition processing."
@@ -122,7 +122,7 @@ async def test_empty_response(self, log_event_catcher):
122122
parallelize_processing=True,
123123
)
124124

125-
ctx = await pipeline._run_pipeline(Message())
125+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
126126

127127
assert ctx.responses == {1: Message()}
128128
assert log_list[-1].msg == "Node has empty response."
@@ -142,7 +142,7 @@ async def call(self, ctx: Context) -> MessageInitTypes:
142142
parallelize_processing=True,
143143
)
144144

145-
ctx = await pipeline._run_pipeline(Message())
145+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
146146

147147
assert ctx.responses == {1: Message()}
148148
assert log_list[-1].msg == "Response was not produced."
@@ -162,7 +162,7 @@ async def call(self, ctx: Context) -> None:
162162
parallelize_processing=True,
163163
)
164164

165-
ctx = await pipeline._run_pipeline(Message())
165+
ctx = await pipeline._run_pipeline(Message(), ctx_id="0")
166166

167167
assert ctx.responses == {1: Message()}
168168
assert log_list[0].msg == "Exception occurred during response processing."

tests/core/test_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def call(self, ctx: Context) -> MessageInitTypes:
125125
return ctx.pipeline.start_label.node_name
126126

127127
pipeline = Pipeline(script={"flow": {"node": {RESPONSE: MyResponse()}}}, start_label=("flow", "node"))
128-
ctx = await pipeline._run_pipeline(Message(text=""))
128+
ctx = await pipeline._run_pipeline(Message(text=""), ctx_id="0")
129129

130130
assert ctx.last_response == Message(text="node")
131131

@@ -145,7 +145,7 @@ async def call(self, ctx: Context) -> None:
145145
script={"flow": {"node": {PRE_RESPONSE: {"": MyProcessing()}, PRE_TRANSITION: {"": MyProcessing()}}}},
146146
start_label=("flow", "node"),
147147
)
148-
ctx = await pipeline._run_pipeline(Message(text=""))
148+
ctx = await pipeline._run_pipeline(Message(text=""), ctx_id="0")
149149
assert len(log) == 2
150150

151151
ctx.framework_data.current_node = None

tests/core/test_message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33
from random import randint
44
from shutil import rmtree
5-
from typing import Hashable, Optional, TextIO
5+
from typing import Optional, TextIO
66
from urllib.request import urlopen
77

88
import pytest
@@ -48,7 +48,7 @@ class ChatskyCLIMessengerInterface(CLIMessengerInterface, MessengerInterfaceWith
4848

4949
def __init__(self, attachments_directory: Optional[Path] = None):
5050
MessengerInterfaceWithAttachments.__init__(self, attachments_directory)
51-
self._ctx_id: Optional[Hashable] = None
51+
self._ctx_id: Optional[str] = None
5252
self._intro: Optional[str] = None
5353
self._prompt_request: str = "request: "
5454
self._prompt_response: str = "response: "

tests/messengers/telegram/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from contextlib import contextmanager
33
from importlib import import_module
44
from hashlib import sha256
5-
from typing import Any, Dict, Hashable, Iterator, List, Optional, Tuple, Union
5+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
66

77
from pydantic import BaseModel
88
from telegram import InputFile, InputMedia, Update
@@ -90,7 +90,7 @@ def _wrap_pipeline_runner(self) -> Iterator[None]:
9090
original_pipeline_runner = self.interface._pipeline_runner
9191

9292
async def wrapped_pipeline_runner(
93-
message: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
93+
message: Message, ctx_id: Optional[str], update_ctx_misc: Optional[dict] = None
9494
) -> Context:
9595
self.latest_ctx = await original_pipeline_runner(message, ctx_id, update_ctx_misc)
9696
return self.latest_ctx

tests/pipeline/test_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def call(self, ctx: Context):
3434
pre_services=[MyProcessing(wait=0.02, text="A")],
3535
post_services=[MyProcessing(wait=0, text="C")],
3636
)
37-
await pipeline._run_pipeline(Message(""))
37+
await pipeline._run_pipeline(Message(""), ctx_id="0")
3838
assert logs == ["A", "B", "C"]
3939

4040

tutorials/messengers/web_api_interface/2_websocket_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
9494
await websocket.send_text(f"User: {data}")
9595
request = Message(data)
9696
context = await messenger_interface.on_request_async(
97-
request, client_id
97+
request, str(client_id)
9898
)
9999
response = context.last_response.text
100100
if response is not None:

0 commit comments

Comments
 (0)