1212from pynumaflow import setup_logging
1313from pynumaflow ._constants import MAX_MESSAGE_SIZE
1414from pynumaflow .proto .sourcetransformer import transform_pb2_grpc
15- from pynumaflow .sourcetransformer import Datum , Messages , Message
15+ from pynumaflow .sourcetransformer import Datum , Messages , Message , SourceTransformer
1616from pynumaflow .sourcetransformer .async_server import SourceTransformAsyncServer
1717from tests .sourcetransform .utils import get_test_datums
1818from tests .testing_utils import (
2626raise_error_from_st = False
2727
2828
29- async def async_transform_handler (keys : list [str ], datum : Datum ) -> Messages :
30- if raise_error_from_st :
31- raise ValueError ("Exception thrown from transform" )
32- val = datum .value
33- msg = "payload:{} event_time:{} " .format (
34- val .decode ("utf-8" ),
35- datum .event_time ,
36- )
37- val = bytes (msg , encoding = "utf-8" )
38- messages = Messages ()
39- messages .append (Message (val , mock_new_event_time (), keys = keys ))
40- return messages
29+ class TestAsyncSourceTrn (SourceTransformer ):
30+ async def handler (self , keys : list [str ], datum : Datum ) -> Messages :
31+ if raise_error_from_st :
32+ raise ValueError ("Exception thrown from transform" )
33+ val = datum .value
34+ msg = "payload:{} event_time:{} " .format (
35+ val .decode ("utf-8" ),
36+ datum .event_time ,
37+ )
38+ val = bytes (msg , encoding = "utf-8" )
39+ messages = Messages ()
40+ messages .append (Message (val , mock_new_event_time (), keys = keys ))
41+ return messages
4142
4243
4344def request_generator (req ):
@@ -55,7 +56,8 @@ def startup_callable(loop):
5556
5657
5758def new_async_st ():
58- server = SourceTransformAsyncServer (source_transform_instance = async_transform_handler )
59+ handle = TestAsyncSourceTrn ()
60+ server = SourceTransformAsyncServer (source_transform_instance = handle )
5961 udfs = server .servicer
6062 return udfs
6163
@@ -251,20 +253,17 @@ def __stub(self):
251253 return transform_pb2_grpc .SourceTransformStub (_channel )
252254
253255 def test_max_threads (self ):
256+ handle = TestAsyncSourceTrn ()
254257 # max cap at 16
255- server = SourceTransformAsyncServer (
256- source_transform_instance = async_transform_handler , max_threads = 32
257- )
258+ server = SourceTransformAsyncServer (source_transform_instance = handle , max_threads = 32 )
258259 self .assertEqual (server .max_threads , 16 )
259260
260261 # use argument provided
261- server = SourceTransformAsyncServer (
262- source_transform_instance = async_transform_handler , max_threads = 5
263- )
262+ server = SourceTransformAsyncServer (source_transform_instance = handle , max_threads = 5 )
264263 self .assertEqual (server .max_threads , 5 )
265264
266265 # defaults to 4
267- server = SourceTransformAsyncServer (source_transform_instance = async_transform_handler )
266+ server = SourceTransformAsyncServer (source_transform_instance = handle )
268267 self .assertEqual (server .max_threads , 4 )
269268
270269
0 commit comments