Skip to content

Commit d68a706

Browse files
N-giveNathan Givens
andauthored
IWF-348 define data_attributes by prefix (#89)
Co-authored-by: Nathan Givens <[email protected]>
1 parent 5730097 commit d68a706

File tree

5 files changed

+77
-10
lines changed

5 files changed

+77
-10
lines changed

iwf/data_attributes.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,40 @@
77

88
class DataAttributes:
99
_type_store: dict[str, Optional[type]]
10+
_prefix_type_store: dict[str, Optional[type]]
1011
_object_encoder: ObjectEncoder
1112
_current_values: dict[str, Union[EncodedObject, None]]
1213
_updated_values_to_return: dict[str, EncodedObject]
1314

1415
def __init__(
1516
self,
1617
type_store: dict[str, Optional[type]],
18+
prefix_type_store: dict[str, Optional[type]],
1719
object_encoder: ObjectEncoder,
1820
current_values: dict[str, Union[EncodedObject, None]],
1921
):
2022
self._object_encoder = object_encoder
2123
self._type_store = type_store
24+
self._prefix_type_store = prefix_type_store
2225
self._current_values = current_values
2326
self._updated_values_to_return = {}
2427

2528
def get_data_attribute(self, key: str) -> Any:
26-
if key not in self._type_store:
29+
is_registered, registered_type = self._validate_key_and_get_type(key)
30+
if not is_registered:
2731
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")
2832

2933
encoded_object = self._current_values.get(key)
3034
if encoded_object is None:
3135
return None
3236

33-
registered_type = self._type_store[key]
3437
return self._object_encoder.decode(encoded_object, registered_type)
3538

3639
def set_data_attribute(self, key: str, value: Any):
37-
if key not in self._type_store:
40+
is_registered, registered_type = self._validate_key_and_get_type(key)
41+
if not is_registered:
3842
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")
3943

40-
registered_type = self._type_store[key]
4144
if registered_type is not None and not isinstance(value, registered_type):
4245
raise WorkflowDefinitionError(
4346
f"data attribute %s is of the right type {registered_type}"
@@ -49,3 +52,13 @@ def set_data_attribute(self, key: str, value: Any):
4952

5053
def get_updated_values_to_return(self) -> dict[str, EncodedObject]:
5154
return self._updated_values_to_return
55+
56+
def _validate_key_and_get_type(self, key) -> tuple[bool, Optional[type]]:
57+
if key in self._type_store:
58+
return (True, self._type_store.get(key))
59+
60+
for prefix, t in self._prefix_type_store.items():
61+
if key.startswith(prefix):
62+
return (True, t)
63+
64+
return (False, None)

iwf/persistence_schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class PersistenceFieldType(Enum):
99
DataAttribute = 1
1010
SearchAttribute = 2
11+
DataAttributePrefix = 3
1112

1213

1314
@dataclass
@@ -27,6 +28,12 @@ def search_attribute_def(cls, key: str, sa_type: SearchAttributeValueType):
2728
key, PersistenceFieldType.SearchAttribute, None, sa_type
2829
)
2930

31+
@classmethod
32+
def data_attribute_prefix_def(cls, key: str, value_type: Optional[type]):
33+
return PersistenceField(
34+
key, PersistenceFieldType.DataAttributePrefix, value_type
35+
)
36+
3037

3138
@dataclass
3239
class PersistenceSchema:

iwf/registry.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Registry:
1717
_internal_channel_type_store: dict[str, TypeStore]
1818
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
1919
_data_attribute_types: dict[str, dict[str, Optional[type]]]
20+
_data_attribute_prefix_types: dict[str, dict[str, Optional[type]]]
2021
_search_attribute_types: dict[str, dict[str, SearchAttributeValueType]]
2122
_rpc_infos: dict[str, dict[str, RPCInfo]]
2223

@@ -27,6 +28,7 @@ def __init__(self):
2728
self._internal_channel_type_store = dict()
2829
self._signal_channel_type_store = dict()
2930
self._data_attribute_types = dict()
31+
self._data_attribute_prefix_types = dict()
3032
self._search_attribute_types = {}
3133
self._rpc_infos = dict()
3234

@@ -36,6 +38,7 @@ def add_workflow(self, wf: ObjectWorkflow):
3638
self._register_internal_channels(wf)
3739
self._register_signal_channels(wf)
3840
self._register_data_attributes(wf)
41+
self._register_data_attribute_prefix_types(wf)
3942
self._register_search_attributes(wf)
4043
self._register_workflow_rpcs(wf)
4144

@@ -82,6 +85,11 @@ def get_search_attribute_types(
8285
) -> dict[str, SearchAttributeValueType]:
8386
return self._search_attribute_types[wf_type]
8487

88+
def get_data_attribute_prefix_types(
89+
self, wf_type: str
90+
) -> dict[str, Optional[type]]:
91+
return self._data_attribute_prefix_types[wf_type]
92+
8593
def get_rpc_infos(self, wf_type: str) -> dict[str, RPCInfo]:
8694
return self._rpc_infos[wf_type]
8795

@@ -115,11 +123,11 @@ def _register_signal_channels(self, wf: ObjectWorkflow):
115123

116124
def _register_data_attributes(self, wf: ObjectWorkflow):
117125
wf_type = get_workflow_type(wf)
118-
types: dict[str, Optional[type]] = {}
126+
data_attribute_types: dict[str, Optional[type]] = {}
119127
for field in wf.get_persistence_schema().persistence_fields:
120128
if field.field_type == PersistenceFieldType.DataAttribute:
121-
types[field.key] = field.value_type
122-
self._data_attribute_types[wf_type] = types
129+
data_attribute_types[field.key] = field.value_type
130+
self._data_attribute_types[wf_type] = data_attribute_types
123131

124132
def _register_search_attributes(self, wf: ObjectWorkflow):
125133
wf_type = get_workflow_type(wf)
@@ -138,6 +146,14 @@ def _register_search_attributes(self, wf: ObjectWorkflow):
138146
types[field.key] = sa_type
139147
self._search_attribute_types[wf_type] = types
140148

149+
def _register_data_attribute_prefix_types(self, wf: ObjectWorkflow):
150+
wf_type = get_workflow_type(wf)
151+
data_attribute_prefix_types: dict[str, Optional[type]] = {}
152+
for field in wf.get_persistence_schema().persistence_fields:
153+
if field.field_type == PersistenceFieldType.DataAttributePrefix:
154+
data_attribute_prefix_types[field.key] = field.value_type
155+
self._data_attribute_prefix_types[wf_type] = data_attribute_prefix_types
156+
141157
def _register_workflow_state(self, wf):
142158
wf_type = get_workflow_type(wf)
143159
state_map = {}

iwf/tests/test_persistence_data_attributes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
final_initial_da_value_1 = initial_da_value_1
3131
final_initial_da_value_2 = "no-more-init"
3232

33+
test_da_prefix = "test-da-prefix"
34+
test_da_prefix_key_1 = "test-da-prefix-1"
35+
test_da_prefix_key_2 = "test-da-prefix-2"
36+
test_da_prefix_value_1 = "test-da-value-1"
37+
test_da_prefix_value_2 = "test-da-value-2"
38+
3339

3440
class DataAttributeRWState(WorkflowState[None]):
3541
def wait_until(
@@ -60,6 +66,8 @@ def execute(
6066
persistence.set_data_attribute(test_da_1, final_test_da_value_1)
6167
persistence.set_data_attribute(test_da_2, final_test_da_value_2)
6268
persistence.set_data_attribute(initial_da_2, final_initial_da_value_2)
69+
persistence.set_data_attribute(test_da_prefix_key_1, test_da_prefix_value_1)
70+
persistence.set_data_attribute(test_da_prefix_key_2, test_da_prefix_value_2)
6371
return StateDecision.graceful_complete_workflow()
6472

6573

@@ -73,6 +81,7 @@ def get_persistence_schema(self) -> PersistenceSchema:
7381
PersistenceField.data_attribute_def(initial_da_2, str),
7482
PersistenceField.data_attribute_def(test_da_1, str),
7583
PersistenceField.data_attribute_def(test_da_2, int),
84+
PersistenceField.data_attribute_prefix_def(test_da_prefix, str),
7685
)
7786

7887
@rpc()
@@ -82,6 +91,8 @@ def test_persistence_read(self, pers: Persistence):
8291
pers.get_data_attribute(initial_da_2),
8392
pers.get_data_attribute(test_da_1),
8493
pers.get_data_attribute(test_da_2),
94+
pers.get_data_attribute(test_da_prefix_key_1),
95+
pers.get_data_attribute(test_da_prefix_key_2),
8596
)
8697

8798

@@ -115,4 +126,6 @@ def test_persistence_data_attributes_workflow(self):
115126
final_initial_da_value_2,
116127
final_test_da_value_1,
117128
final_test_da_value_2,
129+
test_da_prefix_value_1,
130+
test_da_prefix_value_2,
118131
]

iwf/worker_service.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def handle_workflow_worker_rpc(
7979
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
8080
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
8181
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
82+
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
83+
wf_type
84+
)
8285

8386
context = _from_idl_context(request.context)
8487
_input = self._options.object_encoder.decode(
@@ -93,7 +96,10 @@ def handle_workflow_worker_rpc(
9396
}
9497

9598
data_attributes = DataAttributes(
96-
data_attributes_types, self._options.object_encoder, current_data_attributes
99+
data_attributes_types,
100+
data_attributes_prefix_types,
101+
self._options.object_encoder,
102+
current_data_attributes,
97103
)
98104

99105
search_attributes_types = self._registry.get_search_attribute_types(wf_type)
@@ -176,6 +182,9 @@ def handle_workflow_state_wait_until(
176182
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
177183
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
178184
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
185+
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
186+
wf_type
187+
)
179188

180189
context = _from_idl_context(request.context)
181190
_input = self._options.object_encoder.decode(
@@ -190,7 +199,10 @@ def handle_workflow_state_wait_until(
190199
}
191200

192201
data_attributes = DataAttributes(
193-
data_attributes_types, self._options.object_encoder, current_data_attributes
202+
data_attributes_types,
203+
data_attributes_prefix_types,
204+
self._options.object_encoder,
205+
current_data_attributes,
194206
)
195207

196208
search_attributes_types = self._registry.get_search_attribute_types(wf_type)
@@ -257,6 +269,9 @@ def handle_workflow_state_execute(
257269
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
258270
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
259271
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
272+
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
273+
wf_type
274+
)
260275
context = _from_idl_context(request.context)
261276

262277
_input = self._options.object_encoder.decode(
@@ -271,7 +286,10 @@ def handle_workflow_state_execute(
271286
}
272287

273288
data_attributes = DataAttributes(
274-
data_attributes_types, self._options.object_encoder, current_data_attributes
289+
data_attributes_types,
290+
data_attributes_prefix_types,
291+
self._options.object_encoder,
292+
current_data_attributes,
275293
)
276294

277295
search_attributes_types = self._registry.get_search_attribute_types(wf_type)

0 commit comments

Comments
 (0)