Skip to content

Commit a5a53b3

Browse files
Fix rpc input deserialization (#76)
* a fix to the rpc input * Fix rpc input deserialization * output type assert * fix lint
1 parent c4078ab commit a5a53b3

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

iwf/rpc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
if v.annotation not in valid_param_types_exclude_input:
7070
if not has_input:
7171
has_input = True
72+
rpc_info.input_type = v.annotation
7273
else:
7374
raise rpc_definition_err
7475
if not need_persistence:

iwf/tests/test_rpc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import time
33
import unittest
4+
from dataclasses import dataclass
45

56
from iwf.client import Client
67
from iwf.command_request import CommandRequest, InternalChannelCommand
@@ -22,6 +23,12 @@
2223
idle_channel_name = "test-3"
2324

2425

26+
@dataclass
27+
class Mydata:
28+
strdata: str
29+
intdata: int
30+
31+
2532
class WaitState(WorkflowState[None]):
2633
def wait_until(
2734
self,
@@ -102,6 +109,12 @@ def test_rpc_publish_to_idle_channel(self, com: Communication, data: str):
102109
com.publish_to_internal_channel(idle_channel_name, data)
103110
return com.get_internal_channel_size(idle_channel_name)
104111

112+
@rpc()
113+
def test_rpc_input_type(self, input: Mydata) -> Mydata:
114+
if input.intdata != 100 or input.strdata != "test":
115+
raise Exception("input type test failed")
116+
return input
117+
105118

106119
class TestRPCs(unittest.TestCase):
107120
@classmethod
@@ -113,6 +126,13 @@ def setUpClass(cls):
113126
def test_simple_rpc(self):
114127
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
115128
self.client.start_workflow(RPCWorkflow, wf_id, 10)
129+
130+
input = Mydata("test", 100)
131+
output = self.client.invoke_rpc(
132+
wf_id, RPCWorkflow.test_rpc_input_type, input, Mydata
133+
)
134+
assert output == input
135+
116136
output = self.client.invoke_rpc(wf_id, RPCWorkflow.test_simple_rpc)
117137
assert output == 123
118138
wf = RPCWorkflow()

0 commit comments

Comments
 (0)