Skip to content

Commit db691f7

Browse files
authored
IWF-543: Support overriding state options (#79)
* IWF-543: Support overriding state options * IWF-543: Fix types * IWF-543: Add test * IWF-543: Lint * IWF-543: Lint
1 parent efb2583 commit db691f7

File tree

3 files changed

+141
-9
lines changed

3 files changed

+141
-9
lines changed

iwf/state_decision.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import typing
44

55
from iwf.iwf_api.models import WorkflowConditionalClose, WorkflowConditionalCloseType
6+
from iwf.workflow_state_options import WorkflowStateOptions
67

78
if typing.TYPE_CHECKING:
89
from iwf.registry import Registry
910
from iwf.workflow_state import WorkflowState
1011

1112
from dataclasses import dataclass
12-
from typing import Any, List, Union
13+
from typing import Any, List, Union, Optional
1314

1415
from iwf.iwf_api.models.state_decision import StateDecision as IdlStateDecision
1516

@@ -49,9 +50,14 @@ def force_fail_workflow(cls, output: Any = None) -> StateDecision:
4950

5051
@classmethod
5152
def single_next_state(
52-
cls, state: Union[str, type[WorkflowState]], state_input: Any = None
53+
cls,
54+
state: Union[str, type[WorkflowState]],
55+
state_input: Any = None,
56+
state_options_override: Optional[WorkflowStateOptions] = None,
5357
) -> StateDecision:
54-
return StateDecision([StateMovement.create(state, state_input)])
58+
return StateDecision(
59+
[StateMovement.create(state, state_input, state_options_override)]
60+
)
5561

5662
@classmethod
5763
def multi_next_states(

iwf/state_movement.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import typing
4-
from typing import Union
4+
from typing import Union, Optional
55

66
from iwf.errors import WorkflowDefinitionError
77

@@ -14,12 +14,12 @@
1414
from dataclasses import dataclass
1515
from typing import Any
1616

17+
1718
from iwf.iwf_api.models.state_movement import StateMovement as IdlStateMovement
1819

1920
from iwf.object_encoder import ObjectEncoder
2021

21-
from iwf.workflow_state_options import _to_idl_state_options
22-
22+
from iwf.workflow_state_options import _to_idl_state_options, WorkflowStateOptions
2323

2424
reserved_state_id_prefix = "_SYS_"
2525

@@ -35,6 +35,7 @@
3535
class StateMovement:
3636
state_id: str
3737
state_input: Any = None
38+
state_options_override: Optional[WorkflowStateOptions] = None
3839

3940
dead_end: typing.ClassVar[StateMovement]
4041

@@ -52,7 +53,10 @@ def force_fail_workflow(cls, output: Any = None) -> StateMovement:
5253

5354
@classmethod
5455
def create(
55-
cls, state: Union[str, type[WorkflowState]], state_input: Any = None
56+
cls,
57+
state: Union[str, type[WorkflowState]],
58+
state_input: Any = None,
59+
state_options_override: Optional[WorkflowStateOptions] = None,
5660
) -> StateMovement:
5761
if isinstance(state, str):
5862
state_id = state
@@ -64,7 +68,7 @@ def create(
6468
state_id = get_state_id_by_class(state)
6569
if state_id.startswith(reserved_state_id_prefix):
6670
raise WorkflowDefinitionError("cannot use reserved stateId")
67-
return StateMovement(state_id, state_input)
71+
return StateMovement(state_id, state_input, state_options_override)
6872

6973

7074
StateMovement.dead_end = StateMovement(dead_end_sys_state_id)
@@ -83,9 +87,14 @@ def _to_idl_state_movement(
8387
should_skip_wait_until,
8488
)
8589

90+
if movement.state_options_override is not None:
91+
options = movement.state_options_override
92+
else:
93+
options = state.get_state_options()
94+
8695
idl_state_options = _to_idl_state_options(
8796
should_skip_wait_until(state),
88-
state.get_state_options(),
97+
options,
8998
registry.get_state_store(wf_type),
9099
)
91100

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import inspect
2+
import time
3+
import unittest
4+
5+
from iwf.client import Client
6+
from iwf.command_request import CommandRequest
7+
from iwf.command_results import CommandResults
8+
from iwf.communication import Communication
9+
from iwf.persistence_schema import PersistenceSchema, PersistenceField
10+
from iwf.workflow_options import WorkflowOptions
11+
12+
from iwf.iwf_api.models import RetryPolicy, WaitUntilApiFailurePolicy, IDReusePolicy
13+
from iwf.persistence import Persistence
14+
from iwf.state_decision import StateDecision
15+
from iwf.state_schema import StateSchema
16+
from iwf.tests.worker_server import registry
17+
from iwf.workflow import ObjectWorkflow
18+
from iwf.workflow_context import WorkflowContext
19+
from iwf.workflow_state import T, WorkflowState
20+
from iwf.workflow_state_options import WorkflowStateOptions
21+
22+
output_da = "output_da"
23+
24+
25+
class InitState(WorkflowState[str]):
26+
def wait_until(
27+
self,
28+
ctx: WorkflowContext,
29+
input: T,
30+
persistence: Persistence,
31+
communication: Communication,
32+
) -> CommandRequest:
33+
persistence.set_data_attribute(
34+
output_da, str(input) + "_InitState_waitUntil_completed"
35+
)
36+
return CommandRequest.empty()
37+
38+
def execute(
39+
self,
40+
ctx: WorkflowContext,
41+
input: T,
42+
command_results: CommandResults,
43+
persistence: Persistence,
44+
communication: Communication,
45+
) -> StateDecision:
46+
data = persistence.get_data_attribute(output_da)
47+
data += "_InitState_execute_completed"
48+
return StateDecision.single_next_state(
49+
NonInitState,
50+
data,
51+
WorkflowStateOptions(
52+
wait_until_api_retry_policy=RetryPolicy(maximum_attempts=2),
53+
proceed_to_execute_when_wait_until_retry_exhausted=WaitUntilApiFailurePolicy.PROCEED_ON_FAILURE,
54+
),
55+
)
56+
57+
58+
class NonInitState(WorkflowState[str]):
59+
def wait_until(
60+
self,
61+
ctx: WorkflowContext,
62+
input: T,
63+
persistence: Persistence,
64+
communication: Communication,
65+
) -> CommandRequest:
66+
raise RuntimeError("test failure")
67+
68+
def execute(
69+
self,
70+
ctx: WorkflowContext,
71+
input: T,
72+
command_results: CommandResults,
73+
persistence: Persistence,
74+
communication: Communication,
75+
) -> StateDecision:
76+
data = str(input) + "_NonInitState_execute_completed"
77+
return StateDecision.graceful_complete_workflow(data)
78+
79+
def get_state_options(self) -> WorkflowStateOptions:
80+
return WorkflowStateOptions(
81+
wait_until_api_retry_policy=RetryPolicy(maximum_attempts=1),
82+
proceed_to_execute_when_wait_until_retry_exhausted=WaitUntilApiFailurePolicy.FAIL_WORKFLOW_ON_FAILURE,
83+
)
84+
85+
86+
class StateOptionsOverrideWorkflow(ObjectWorkflow):
87+
def get_persistence_schema(self) -> PersistenceSchema:
88+
return PersistenceSchema.create(
89+
PersistenceField.data_attribute_def(output_da, str),
90+
)
91+
92+
def get_workflow_states(self) -> StateSchema:
93+
return StateSchema.with_starting_state(InitState(), NonInitState())
94+
95+
96+
class TestStateOptionsOverrideWorkflow(unittest.TestCase):
97+
@classmethod
98+
def setUpClass(cls):
99+
wf = StateOptionsOverrideWorkflow()
100+
registry.add_workflow(wf)
101+
cls.client = Client(registry)
102+
103+
def test_override(self):
104+
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
105+
self.client.start_workflow(
106+
StateOptionsOverrideWorkflow,
107+
wf_id,
108+
10,
109+
"input",
110+
WorkflowOptions(workflow_id_reuse_policy=IDReusePolicy.DISALLOW_REUSE),
111+
)
112+
output = self.client.wait_for_workflow_completion(wf_id)
113+
114+
assert (
115+
output
116+
== "input_InitState_waitUntil_completed_InitState_execute_completed_NonInitState_execute_completed"
117+
)

0 commit comments

Comments
 (0)