Skip to content

Commit 3a8b199

Browse files
N-giveNathan Givens
andauthored
IWF-681 add data attributes API methods (#94)
* IWF-681 add missing data attributes API methods --------- Co-authored-by: Nathan Givens <[email protected]>
1 parent b6ab93e commit 3a8b199

File tree

7 files changed

+150
-78
lines changed

7 files changed

+150
-78
lines changed

iwf/client.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import inspect
2-
from typing import Any, Callable, Optional, Type, TypeVar, Union
2+
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
33

44
from typing_extensions import deprecated
55

66
from iwf.client_options import ClientOptions
7-
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
7+
from iwf.errors import InvalidArgumentError, NotRegisteredError, WorkflowDefinitionError
88
from iwf.iwf_api.models import (
99
SearchAttribute,
1010
SearchAttributeKeyAndType,
@@ -162,6 +162,71 @@ def stop_workflow(
162162
):
163163
return self._unregistered_client.stop_workflow(workflow_id, "", options)
164164

165+
def get_all_workflow_data_attributes(
166+
self,
167+
workflow_class: type[ObjectWorkflow],
168+
workflow_id: str,
169+
workflow_run_id: str = "",
170+
):
171+
return self.get_workflow_data_attributes(
172+
workflow_class, workflow_id, workflow_run_id, None
173+
)
174+
175+
def get_workflow_data_attributes(
176+
self,
177+
workflow_class: type[ObjectWorkflow],
178+
workflow_id: str,
179+
workflow_run_id: str = "",
180+
keys: Optional[List[str]] = None,
181+
):
182+
wf_type = get_workflow_type_by_class(workflow_class)
183+
data_attr_type_store = self._registry.get_data_attribute_types(wf_type)
184+
if keys:
185+
for key in keys:
186+
if not data_attr_type_store.is_valid_name_or_prefix(key):
187+
raise NotRegisteredError(
188+
f"key {key} is not registered in workflow {wf_type}"
189+
)
190+
191+
response = self._unregistered_client.get_workflow_data_attributes(
192+
workflow_id, workflow_run_id, keys
193+
)
194+
195+
if not response.objects:
196+
raise RuntimeError("data attributes not returned")
197+
198+
res = {}
199+
for kv in response.objects:
200+
k = unset_to_none(kv.key)
201+
if k and kv.value:
202+
res[kv.key] = self._options.object_encoder.decode(
203+
kv.value, data_attr_type_store.get_type(k)
204+
)
205+
206+
return res
207+
208+
def set_workflow_data_attributes(
209+
self,
210+
workflow_class: type[ObjectWorkflow],
211+
workflow_id: str,
212+
workflow_run_id: str = "",
213+
data_attributes: dict[str, Any] = dict(),
214+
):
215+
wf_type = get_workflow_type_by_class(workflow_class)
216+
data_attr_type_store = self._registry.get_data_attribute_types(wf_type)
217+
for key, value in data_attributes.items():
218+
if not data_attr_type_store.is_valid_name_or_prefix(key):
219+
raise NotRegisteredError(f"data attribute {key} is not registered")
220+
221+
data_attr_type = data_attr_type_store.get_type(key)
222+
if not isinstance(value, data_attr_type):
223+
raise NotRegisteredError(
224+
f"data attribute {key} is not registered as {type(value)}"
225+
)
226+
return self._unregistered_client.set_workflow_data_attributes(
227+
workflow_id, workflow_run_id, data_attributes
228+
)
229+
165230
def invoke_rpc(
166231
self,
167232
workflow_id: str,

iwf/data_attributes.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,46 @@
1-
from typing import Any, Optional, Union
1+
from typing import Any, Union
22

33
from iwf.errors import WorkflowDefinitionError
44
from iwf.iwf_api.models import EncodedObject
55
from iwf.object_encoder import ObjectEncoder
6+
from iwf.type_store import TypeStore
67

78

89
class DataAttributes:
9-
_type_store: dict[str, Optional[type]]
10-
_prefix_type_store: dict[str, Optional[type]]
10+
_type_store: TypeStore
1111
_object_encoder: ObjectEncoder
1212
_current_values: dict[str, Union[EncodedObject, None]]
1313
_updated_values_to_return: dict[str, EncodedObject]
1414

1515
def __init__(
1616
self,
17-
type_store: dict[str, Optional[type]],
18-
prefix_type_store: dict[str, Optional[type]],
17+
type_store: TypeStore,
1918
object_encoder: ObjectEncoder,
2019
current_values: dict[str, Union[EncodedObject, None]],
2120
):
2221
self._object_encoder = object_encoder
2322
self._type_store = type_store
24-
self._prefix_type_store = prefix_type_store
2523
self._current_values = current_values
2624
self._updated_values_to_return = {}
2725

2826
def get_data_attribute(self, key: str) -> Any:
29-
is_registered, registered_type = self._validate_key_and_get_type(key)
27+
is_registered = self._type_store.is_valid_name_or_prefix(key)
3028
if not is_registered:
3129
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")
3230

3331
encoded_object = self._current_values.get(key)
3432
if encoded_object is None:
3533
return None
3634

35+
registered_type = self._type_store.get_type(key)
3736
return self._object_encoder.decode(encoded_object, registered_type)
3837

3938
def set_data_attribute(self, key: str, value: Any):
40-
is_registered, registered_type = self._validate_key_and_get_type(key)
39+
is_registered = self._type_store.is_valid_name_or_prefix(key)
4140
if not is_registered:
4241
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")
4342

43+
registered_type = self._type_store.get_type(key)
4444
if registered_type is not None and not isinstance(value, registered_type):
4545
raise WorkflowDefinitionError(
4646
f"data attribute %s is of the right type {registered_type}"
@@ -52,13 +52,3 @@ def set_data_attribute(self, key: str, value: Any):
5252

5353
def get_updated_values_to_return(self) -> dict[str, EncodedObject]:
5454
return self._updated_values_to_return
55-
56-
def _validate_key_and_get_type(self, key) -> tuple[bool, Optional[type]]:
57-
if key in self._type_store:
58-
return (True, self._type_store.get(key))
59-
60-
for prefix, t in self._prefix_type_store.items():
61-
if key.startswith(prefix):
62-
return (True, t)
63-
64-
return (False, None)

iwf/registry.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ class Registry:
1616
_state_store: dict[str, dict[str, WorkflowState]]
1717
_internal_channel_type_store: dict[str, TypeStore]
1818
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
19-
_data_attribute_types: dict[str, dict[str, Optional[type]]]
20-
_data_attribute_prefix_types: dict[str, dict[str, Optional[type]]]
19+
_data_attribute_types: dict[str, TypeStore]
2120
_search_attribute_types: dict[str, dict[str, SearchAttributeValueType]]
2221
_rpc_infos: dict[str, dict[str, RPCInfo]]
2322

@@ -28,7 +27,6 @@ def __init__(self):
2827
self._internal_channel_type_store = dict()
2928
self._signal_channel_type_store = dict()
3029
self._data_attribute_types = dict()
31-
self._data_attribute_prefix_types = dict()
3230
self._search_attribute_types = {}
3331
self._rpc_infos = dict()
3432

@@ -38,7 +36,6 @@ def add_workflow(self, wf: ObjectWorkflow):
3836
self._register_internal_channels(wf)
3937
self._register_signal_channels(wf)
4038
self._register_data_attributes(wf)
41-
self._register_data_attribute_prefix_types(wf)
4239
self._register_search_attributes(wf)
4340
self._register_workflow_rpcs(wf)
4441

@@ -77,19 +74,14 @@ def get_internal_channel_type_store(self, wf_type: str) -> TypeStore:
7774
def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
7875
return self._signal_channel_type_store[wf_type]
7976

80-
def get_data_attribute_types(self, wf_type: str) -> dict[str, Optional[type]]:
77+
def get_data_attribute_types(self, wf_type: str) -> TypeStore:
8178
return self._data_attribute_types[wf_type]
8279

8380
def get_search_attribute_types(
8481
self, wf_type: str
8582
) -> dict[str, SearchAttributeValueType]:
8683
return self._search_attribute_types[wf_type]
8784

88-
def get_data_attribute_prefix_types(
89-
self, wf_type: str
90-
) -> dict[str, Optional[type]]:
91-
return self._data_attribute_prefix_types[wf_type]
92-
9385
def get_rpc_infos(self, wf_type: str) -> dict[str, RPCInfo]:
9486
return self._rpc_infos[wf_type]
9587

@@ -123,10 +115,13 @@ def _register_signal_channels(self, wf: ObjectWorkflow):
123115

124116
def _register_data_attributes(self, wf: ObjectWorkflow):
125117
wf_type = get_workflow_type(wf)
126-
data_attribute_types: dict[str, Optional[type]] = {}
118+
data_attribute_types: TypeStore = TypeStore(Type.DATA_ATTRIBUTE)
127119
for field in wf.get_persistence_schema().persistence_fields:
128-
if field.field_type == PersistenceFieldType.DataAttribute:
129-
data_attribute_types[field.key] = field.value_type
120+
if (
121+
field.field_type == PersistenceFieldType.DataAttribute
122+
or field.field_type == PersistenceFieldType.DataAttributePrefix
123+
):
124+
data_attribute_types.add_data_attribute_def(field)
130125
self._data_attribute_types[wf_type] = data_attribute_types
131126

132127
def _register_search_attributes(self, wf: ObjectWorkflow):
@@ -146,14 +141,6 @@ def _register_search_attributes(self, wf: ObjectWorkflow):
146141
types[field.key] = sa_type
147142
self._search_attribute_types[wf_type] = types
148143

149-
def _register_data_attribute_prefix_types(self, wf: ObjectWorkflow):
150-
wf_type = get_workflow_type(wf)
151-
data_attribute_prefix_types: dict[str, Optional[type]] = {}
152-
for field in wf.get_persistence_schema().persistence_fields:
153-
if field.field_type == PersistenceFieldType.DataAttributePrefix:
154-
data_attribute_prefix_types[field.key] = field.value_type
155-
self._data_attribute_prefix_types[wf_type] = data_attribute_prefix_types
156-
157144
def _register_workflow_state(self, wf):
158145
wf_type = get_workflow_type(wf)
159146
state_map = {}

iwf/tests/test_persistence_data_attributes.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from iwf.workflow_context import WorkflowContext
1616
from iwf.workflow_options import WorkflowOptions
1717
from iwf.workflow_state import T, WorkflowState
18-
from iwf.rpc import rpc
1918

2019
initial_da_1 = "initial_da_1"
2120
initial_da_value_1 = "value_1"
@@ -35,6 +34,18 @@
3534
test_da_prefix_key_2 = "test-da-prefix-2"
3635
test_da_prefix_value_1 = "test-da-value-1"
3736
test_da_prefix_value_2 = "test-da-value-2"
37+
test_da_set_key = "test_da_set_key"
38+
test_da_set_value = "test_da_set_value"
39+
40+
expected_final_das = {
41+
initial_da_1: initial_da_value_1,
42+
initial_da_2: final_initial_da_value_2,
43+
test_da_1: final_test_da_value_1,
44+
test_da_2: final_test_da_value_2,
45+
test_da_prefix_key_1: test_da_prefix_value_1,
46+
test_da_prefix_key_2: test_da_prefix_value_2,
47+
test_da_set_key: test_da_set_value,
48+
}
3849

3950

4051
class DataAttributeRWState(WorkflowState[None]):
@@ -82,17 +93,7 @@ def get_persistence_schema(self) -> PersistenceSchema:
8293
PersistenceField.data_attribute_def(test_da_1, str),
8394
PersistenceField.data_attribute_def(test_da_2, int),
8495
PersistenceField.data_attribute_prefix_def(test_da_prefix, str),
85-
)
86-
87-
@rpc()
88-
def test_persistence_read(self, pers: Persistence):
89-
return (
90-
pers.get_data_attribute(initial_da_1),
91-
pers.get_data_attribute(initial_da_2),
92-
pers.get_data_attribute(test_da_1),
93-
pers.get_data_attribute(test_da_2),
94-
pers.get_data_attribute(test_da_prefix_key_1),
95-
pers.get_data_attribute(test_da_prefix_key_2),
96+
PersistenceField.data_attribute_def(test_da_set_key, str),
9697
)
9798

9899

@@ -113,19 +114,22 @@ def test_persistence_data_attributes_workflow(self):
113114
},
114115
)
115116

116-
self.client.start_workflow(
117+
wf_run_id = self.client.start_workflow(
117118
PersistenceDataAttributesWorkflow, wf_id, 100, None, start_options
118119
)
120+
self.client.set_workflow_data_attributes(
121+
PersistenceDataAttributesWorkflow,
122+
wf_id,
123+
wf_run_id,
124+
{
125+
test_da_set_key: test_da_set_value,
126+
},
127+
)
119128
self.client.wait_for_workflow_completion(wf_id, None)
120129

121-
res = self.client.invoke_rpc(
122-
wf_id, PersistenceDataAttributesWorkflow.test_persistence_read
130+
all_data_attributes = self.client.get_all_workflow_data_attributes(
131+
PersistenceDataAttributesWorkflow, wf_id, wf_run_id
123132
)
124-
assert res == [
125-
final_initial_da_value_1,
126-
final_initial_da_value_2,
127-
final_test_da_value_1,
128-
final_test_da_value_2,
129-
test_da_prefix_value_1,
130-
test_da_prefix_value_2,
131-
]
133+
for k, v in expected_final_das.items():
134+
assert k in all_data_attributes
135+
assert all_data_attributes[k] == v

iwf/type_store.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
from iwf.communication_schema import CommunicationMethod
55
from iwf.errors import WorkflowDefinitionError, NotRegisteredError
6+
from iwf.persistence_schema import PersistenceField, PersistenceFieldType
67

78

89
class Type(Enum):
910
INTERNAL_CHANNEL = 1
11+
DATA_ATTRIBUTE = 2
1012
# TODO: extend to other types
11-
# DATA_ATTRIBUTE = 2
1213
# SIGNAL_CHANNEL = 3
1314

1415

@@ -44,6 +45,17 @@ def add_internal_channel_def(self, obj: CommunicationMethod):
4445
)
4546
self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type)
4647

48+
def add_data_attribute_def(self, obj: PersistenceField):
49+
if self._class_type != Type.DATA_ATTRIBUTE:
50+
raise ValueError(
51+
f"Cannot add internal channel definition to {self._class_type}"
52+
)
53+
self._do_add_to_store(
54+
obj.field_type == PersistenceFieldType.DataAttributePrefix,
55+
obj.key,
56+
obj.value_type,
57+
)
58+
4759
def _validate_name(self, name: str) -> bool:
4860
if name in self._name_to_type_store:
4961
return True

0 commit comments

Comments
 (0)