Skip to content

Commit 5730097

Browse files
N-giveNathan Givens
andauthored
IWF-466 add state_execution_locals and record_events (#87)
* IWF-466 add state_execution_locals and record_events * IWF-466 formatting * IWF-466 remove debug printing * IWF-466 improve readability --------- Co-authored-by: Nathan Givens <[email protected]>
1 parent 343d932 commit 5730097

File tree

4 files changed

+200
-5
lines changed

4 files changed

+200
-5
lines changed

iwf/persistence.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1-
from typing import Any, Union
1+
from typing import Any, Tuple, Union
22

33
from iwf.data_attributes import DataAttributes
44
from iwf.search_attributes import SearchAttributes
5+
from iwf.state_execution_locals import StateExecutionLocals
56

67

78
class Persistence:
89
_data_attributes: DataAttributes
910
_search_attributes: SearchAttributes
11+
_state_execution_locals: StateExecutionLocals
1012

1113
def __init__(
1214
self,
1315
data_attributes: DataAttributes,
1416
search_attributes: SearchAttributes,
17+
state_execution_locals: StateExecutionLocals,
1518
):
1619
self._data_attributes = data_attributes
1720
self._search_attributes = search_attributes
21+
self._state_execution_locals = state_execution_locals
1822

1923
def get_data_attribute(self, key: str) -> Any:
2024
return self._data_attributes.get_data_attribute(key)
@@ -65,3 +69,12 @@ def set_search_attribute_keyword_array(
6569
self, key: str, value: Union[None, list[str]]
6670
):
6771
self._search_attributes.set_search_attribute_keyword_array(key, value)
72+
73+
def get_state_execution_local(self, key: str) -> Any:
74+
return self._state_execution_locals.get_state_execution_local(key)
75+
76+
def set_state_execution_local(self, key: str, value: Any):
77+
self._state_execution_locals.set_state_execution_local(key, value)
78+
79+
def record_event(self, key: str, *event_data: Tuple[Any, ...]):
80+
self._state_execution_locals.record_event(key, *event_data)

iwf/state_execution_locals.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Any, List, Tuple
2+
3+
from iwf.errors import WorkflowDefinitionError
4+
from iwf.iwf_api.models import EncodedObject, KeyValue
5+
from iwf.object_encoder import ObjectEncoder
6+
7+
8+
class StateExecutionLocals:
9+
_record_events: dict[str, EncodedObject]
10+
_attribute_name_to_encoded_object_map: dict[str, EncodedObject]
11+
_upsert_attributes_to_return_to_server: dict[str, EncodedObject]
12+
_object_encoder: ObjectEncoder
13+
14+
def __init__(
15+
self,
16+
attribute_name_to_encoded_object_map: dict[str, EncodedObject],
17+
object_encoder: ObjectEncoder,
18+
):
19+
self._object_encoder = object_encoder
20+
self._attribute_name_to_encoded_object_map = (
21+
attribute_name_to_encoded_object_map
22+
)
23+
self._upsert_attributes_to_return_to_server = {}
24+
self._record_events = {}
25+
26+
def set_state_execution_local(self, key: str, value: Any):
27+
encoded_data = self._object_encoder.encode(value)
28+
self._attribute_name_to_encoded_object_map[key] = encoded_data
29+
self._upsert_attributes_to_return_to_server[key] = encoded_data
30+
31+
def get_state_execution_local(self, key: str) -> Any:
32+
encoded_object = self._attribute_name_to_encoded_object_map.get(key)
33+
if encoded_object is None:
34+
return None
35+
return self._object_encoder.decode(encoded_object)
36+
37+
def record_event(self, key: str, *event_data: Tuple[Any, ...]):
38+
if key in self._record_events:
39+
raise WorkflowDefinitionError("Cannot record the same event more than once")
40+
41+
if event_data is not None and len(event_data) == 1:
42+
self._record_events[key] = self._object_encoder.encode(event_data[0])
43+
44+
self._record_events[key] = self._object_encoder.encode(event_data)
45+
46+
def get_upsert_state_execution_local_attributes(self) -> List[KeyValue]:
47+
return [
48+
KeyValue(item_key, item_value)
49+
for item_key, item_value in self._upsert_attributes_to_return_to_server.items()
50+
]
51+
52+
def get_record_events(self) -> List[KeyValue]:
53+
return [
54+
KeyValue(item_key, item_value)
55+
for item_key, item_value in self._record_events.items()
56+
]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import inspect
2+
import time
3+
import unittest
4+
5+
from iwf.rpc import rpc
6+
from iwf.client import Client
7+
from iwf.command_request import CommandRequest
8+
from iwf.command_results import CommandResults
9+
from iwf.communication import Communication
10+
from iwf.persistence import Persistence
11+
from iwf.persistence_schema import PersistenceField, PersistenceSchema
12+
from iwf.state_decision import StateDecision
13+
from iwf.state_schema import StateSchema
14+
from iwf.tests.worker_server import registry
15+
from iwf.workflow import ObjectWorkflow
16+
from iwf.workflow_context import WorkflowContext
17+
from iwf.workflow_options import WorkflowOptions
18+
from iwf.workflow_state import T, WorkflowState
19+
20+
PERSISTENCE_LOCAL_KEY = "persistence-test-key"
21+
PERSISTENCE_LOCAL_VALUE = "persistence-test-value"
22+
PERSISTENCE_DATA_ATTRIBUTE_KEY = "persistence-data-attribute-key"
23+
24+
25+
class PersistenceStateExecutionLocalRWState(WorkflowState[None]):
26+
def wait_until(
27+
self,
28+
ctx: WorkflowContext,
29+
input: T,
30+
persistence: Persistence,
31+
communication: Communication,
32+
):
33+
persistence.set_state_execution_local(
34+
PERSISTENCE_LOCAL_KEY, PERSISTENCE_LOCAL_VALUE
35+
)
36+
return CommandRequest.empty()
37+
38+
def execute(
39+
self,
40+
ctx: WorkflowContext,
41+
input: T,
42+
command_results: CommandResults,
43+
persistence: Persistence,
44+
communication: Communication,
45+
):
46+
value = persistence.get_state_execution_local(PERSISTENCE_LOCAL_KEY)
47+
persistence.set_data_attribute(PERSISTENCE_DATA_ATTRIBUTE_KEY, value)
48+
return StateDecision.graceful_complete_workflow()
49+
50+
51+
class PersistenceStateExecutionLocalWorkflow(ObjectWorkflow):
52+
def get_workflow_states(self) -> StateSchema:
53+
return StateSchema.with_starting_state(PersistenceStateExecutionLocalRWState())
54+
55+
def get_persistence_schema(self) -> PersistenceSchema:
56+
return PersistenceSchema.create(
57+
PersistenceField.data_attribute_def(PERSISTENCE_DATA_ATTRIBUTE_KEY, str)
58+
)
59+
60+
@rpc()
61+
def test_persistence_read(self, persistence: Persistence):
62+
return persistence.get_data_attribute(PERSISTENCE_DATA_ATTRIBUTE_KEY)
63+
64+
65+
class TestPersistenceExecutionLocalRead(unittest.TestCase):
66+
@classmethod
67+
def setUpClass(cls):
68+
wf = PersistenceStateExecutionLocalWorkflow()
69+
registry.add_workflow(wf)
70+
cls.client = Client(registry)
71+
72+
def test_persistence_execution_local_workflow(self):
73+
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
74+
start_options = WorkflowOptions()
75+
self.client.start_workflow(
76+
PersistenceStateExecutionLocalWorkflow, wf_id, 200, None, start_options
77+
)
78+
self.client.wait_for_workflow_completion(wf_id, None)
79+
res = self.client.invoke_rpc(
80+
wf_id, PersistenceStateExecutionLocalWorkflow.test_persistence_read
81+
)
82+
assert res == PERSISTENCE_LOCAL_VALUE

iwf/worker_service.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import traceback
22
import typing
33
from dataclasses import dataclass
4-
from typing import Union
4+
from typing import List, Union
55

66
from iwf.command_request import _to_idl_command_request
77
from iwf.command_results import from_idl_command_results
@@ -25,6 +25,7 @@
2525
from iwf.registry import Registry
2626
from iwf.search_attributes import SearchAttributes
2727
from iwf.state_decision import StateDecision, _to_idl_state_decision
28+
from iwf.state_execution_locals import StateExecutionLocals
2829
from iwf.utils.iwf_typing import assert_not_unset, unset_to_none
2930
from iwf.workflow_context import WorkflowContext, _from_idl_context
3031
from iwf.workflow_state import get_input_type
@@ -99,8 +100,13 @@ def handle_workflow_worker_rpc(
99100
search_attributes = SearchAttributes(
100101
search_attributes_types, unset_to_none(request.search_attributes)
101102
)
103+
state_execution_locals = StateExecutionLocals(
104+
to_map(None), self._options.object_encoder
105+
)
102106

103-
persistence = Persistence(data_attributes, search_attributes)
107+
persistence = Persistence(
108+
data_attributes, search_attributes, state_execution_locals
109+
)
104110

105111
communication = Communication(
106112
internal_channel_types,
@@ -145,6 +151,9 @@ def handle_workflow_worker_rpc(
145151
)
146152
if upsert_sas:
147153
response.upsert_search_attributes = upsert_sas
154+
record_events = state_execution_locals.get_record_events()
155+
if len(record_events) > 0:
156+
response.record_events = record_events
148157
if len(communication.get_to_trigger_state_movements()) > 0:
149158
movements = communication.get_to_trigger_state_movements()
150159
decision = StateDecision.multi_next_states(*movements)
@@ -188,8 +197,13 @@ def handle_workflow_state_wait_until(
188197
search_attributes = SearchAttributes(
189198
search_attributes_types, unset_to_none(request.search_attributes)
190199
)
200+
state_execution_locals = StateExecutionLocals(
201+
to_map(None), self._options.object_encoder
202+
)
191203

192-
persistence = Persistence(data_attributes, search_attributes)
204+
persistence = Persistence(
205+
data_attributes, search_attributes, state_execution_locals
206+
)
193207

194208
communication = Communication(
195209
internal_channel_types,
@@ -211,13 +225,20 @@ def handle_workflow_state_wait_until(
211225
search_attributes.get_upsert_to_server_string_array_attribute_map(),
212226
)
213227

228+
upsert_state_locals = (
229+
state_execution_locals.get_upsert_state_execution_local_attributes()
230+
)
231+
record_events = state_execution_locals.get_record_events()
232+
214233
response = WorkflowStateWaitUntilResponse(
215234
command_request=_to_idl_command_request(command_request),
216235
publish_to_inter_state_channel=pubs,
217236
upsert_data_objects=[
218237
KeyValue(k, v)
219238
for (k, v) in data_attributes.get_updated_values_to_return().items()
220239
],
240+
upsert_state_locals=upsert_state_locals,
241+
record_events=record_events,
221242
)
222243

223244
if upsert_sas:
@@ -257,8 +278,13 @@ def handle_workflow_state_execute(
257278
search_attributes = SearchAttributes(
258279
search_attributes_types, unset_to_none(request.search_attributes)
259280
)
281+
state_execution_locals = StateExecutionLocals(
282+
to_map(request.state_locals), self._options.object_encoder
283+
)
260284

261-
persistence = Persistence(data_attributes, search_attributes)
285+
persistence = Persistence(
286+
data_attributes, search_attributes, state_execution_locals
287+
)
262288

263289
communication = Communication(
264290
internal_channel_types,
@@ -288,6 +314,10 @@ def handle_workflow_state_execute(
288314
search_attributes.get_upsert_to_server_double_attribute_map(),
289315
search_attributes.get_upsert_to_server_string_array_attribute_map(),
290316
)
317+
upsert_state_locals = (
318+
state_execution_locals.get_upsert_state_execution_local_attributes()
319+
)
320+
record_events = state_execution_locals.get_record_events()
291321

292322
response = WorkflowStateExecuteResponse(
293323
state_decision=_to_idl_state_decision(
@@ -301,6 +331,8 @@ def handle_workflow_state_execute(
301331
KeyValue(k, v)
302332
for (k, v) in data_attributes.get_updated_values_to_return().items()
303333
],
334+
upsert_state_locals=upsert_state_locals,
335+
record_events=record_events,
304336
)
305337

306338
if upsert_sas:
@@ -367,3 +399,15 @@ def _create_upsert_search_attributes(
367399
sas.append(sa)
368400

369401
return sas
402+
403+
404+
def to_map(key_values: Union[None, Unset, List[KeyValue]]) -> dict[str, EncodedObject]:
405+
key_values = unset_to_none(key_values) or []
406+
kvs = {}
407+
for kv in key_values:
408+
k = unset_to_none(kv.key)
409+
v = unset_to_none(kv.value)
410+
if k and v:
411+
kvs[k] = v
412+
413+
return kvs

0 commit comments

Comments
 (0)