Skip to content

Commit 6473d4f

Browse files
committed
class test
Signed-off-by: Sidhant Kohli <sidhant.kohli@gmail.com>
1 parent 77953fa commit 6473d4f

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

tests/sourcetransform/test_async.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pynumaflow import setup_logging
1313
from pynumaflow._constants import MAX_MESSAGE_SIZE
1414
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
15-
from pynumaflow.sourcetransformer import Datum, Messages, Message
15+
from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer
1616
from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer
1717
from tests.sourcetransform.utils import get_test_datums
1818
from tests.testing_utils import (
@@ -26,18 +26,19 @@
2626
raise_error_from_st = False
2727

2828

29-
async def async_transform_handler(keys: list[str], datum: Datum) -> Messages:
30-
if raise_error_from_st:
31-
raise ValueError("Exception thrown from transform")
32-
val = datum.value
33-
msg = "payload:{} event_time:{} ".format(
34-
val.decode("utf-8"),
35-
datum.event_time,
36-
)
37-
val = bytes(msg, encoding="utf-8")
38-
messages = Messages()
39-
messages.append(Message(val, mock_new_event_time(), keys=keys))
40-
return messages
29+
class TestAsyncSourceTrn(SourceTransformer):
30+
async def handler(self, keys: list[str], datum: Datum) -> Messages:
31+
if raise_error_from_st:
32+
raise ValueError("Exception thrown from transform")
33+
val = datum.value
34+
msg = "payload:{} event_time:{} ".format(
35+
val.decode("utf-8"),
36+
datum.event_time,
37+
)
38+
val = bytes(msg, encoding="utf-8")
39+
messages = Messages()
40+
messages.append(Message(val, mock_new_event_time(), keys=keys))
41+
return messages
4142

4243

4344
def request_generator(req):
@@ -55,7 +56,8 @@ def startup_callable(loop):
5556

5657

5758
def new_async_st():
58-
server = SourceTransformAsyncServer(source_transform_instance=async_transform_handler)
59+
handle = TestAsyncSourceTrn()
60+
server = SourceTransformAsyncServer(source_transform_instance=handle)
5961
udfs = server.servicer
6062
return udfs
6163

@@ -251,20 +253,17 @@ def __stub(self):
251253
return transform_pb2_grpc.SourceTransformStub(_channel)
252254

253255
def test_max_threads(self):
256+
handle = TestAsyncSourceTrn()
254257
# max cap at 16
255-
server = SourceTransformAsyncServer(
256-
source_transform_instance=async_transform_handler, max_threads=32
257-
)
258+
server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=32)
258259
self.assertEqual(server.max_threads, 16)
259260

260261
# use argument provided
261-
server = SourceTransformAsyncServer(
262-
source_transform_instance=async_transform_handler, max_threads=5
263-
)
262+
server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=5)
264263
self.assertEqual(server.max_threads, 5)
265264

266265
# defaults to 4
267-
server = SourceTransformAsyncServer(source_transform_instance=async_transform_handler)
266+
server = SourceTransformAsyncServer(source_transform_instance=handle)
268267
self.assertEqual(server.max_threads, 4)
269268

270269

0 commit comments

Comments
 (0)