Skip to content

Commit 8e914ee

Browse files
authored
Merge branch 'serverlessworkflow:main' into main
2 parents ceca99a + 166bd12 commit 8e914ee

File tree

1 file changed

+90
-7
lines changed

1 file changed

+90
-7
lines changed

serverlessworkflow/sdk/state_machine_generator.py

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,67 @@ def sleep_state_details(self):
259259

260260
def event_state_details(self):
261261
if isinstance(self.current_state, EventState):
262-
self.state_to_machine_state(["event_state", "state"])
262+
state = self.state_to_machine_state(["event_state", "state"])
263+
if self.get_actions:
264+
if on_events := self.current_state.onEvents:
265+
state.initial = [] if len(on_events) > 1 else on_events[0]
266+
for i, oe in enumerate(on_events):
267+
state.add_substate(
268+
oe_state := self.state_machine.state_cls(
269+
oe_name := f"onEvent {i}"
270+
)
271+
)
272+
273+
# define initial state
274+
if i == 0 and len(on_events) > 1:
275+
state.initial = [oe_state.name]
276+
elif i == 0 and len(on_events) == 1:
277+
state.initial = oe_state.name
278+
else:
279+
state.initial.append(oe_state.name)
280+
281+
event_names = []
282+
for ie, event in enumerate(oe.eventRefs):
283+
oe_state.add_substate(
284+
ns := self.state_machine.state_cls(event)
285+
)
286+
ns.tags = ["event"]
287+
self.get_action_event(state=ns, e_name=event)
288+
event_names.append(event)
289+
290+
# define initial state
291+
if ie == 0 and len(oe.eventRefs) > 1:
292+
oe_state.initial = [event]
293+
elif ie == 0 and len(oe.eventRefs) == 1:
294+
oe_state.initial = event
295+
else:
296+
oe_state.initial.append(event)
297+
298+
if self.current_state.exclusive:
299+
oe_state.add_substate(
300+
ns := self.state_machine.state_cls(
301+
action_name := f"action {ie}"
302+
)
303+
)
304+
self.state_machine.add_transition(
305+
trigger="",
306+
source=f"{self.current_state.name}.{oe_name}.{event}",
307+
dest=f"{self.current_state.name}.{oe_name}.{action_name}",
308+
)
309+
self.generate_actions_info(
310+
machine_state=ns,
311+
state_name=f"{self.current_state.name}.{oe_name}.{action_name}",
312+
actions=oe.actions,
313+
action_mode=oe.actionMode,
314+
)
315+
if not self.current_state.exclusive and oe.actions:
316+
self.generate_actions_info(
317+
machine_state=oe_state,
318+
state_name=f"{self.current_state.name}.{oe_name}",
319+
actions=oe.actions,
320+
action_mode=oe.actionMode,
321+
initial_states=event_names,
322+
)
263323

264324
def foreach_state_details(self):
265325
if isinstance(self.current_state, ForEachState):
@@ -352,6 +412,7 @@ def generate_actions_info(
352412
state_name: str,
353413
actions: List[Dict[str, Action]],
354414
action_mode: str = "sequential",
415+
initial_states: List[str] = [],
355416
):
356417
if self.get_actions:
357418
parallel_states = []
@@ -386,7 +447,11 @@ def generate_actions_info(
386447
ns := self.state_machine.state_cls(name)
387448
)
388449
ns.tags = ["event"]
389-
self.get_action_event(state=ns, e_name=name)
450+
self.get_action_event(
451+
state=ns,
452+
e_name=action.eventRef.triggerEventRef,
453+
er_name=action.eventRef.resultEventRef,
454+
)
390455
if name:
391456
if action_mode == "sequential":
392457
if i < len(actions) - 1:
@@ -438,19 +503,36 @@ def generate_actions_info(
438503
)
439504
ns.tags = ["event"]
440505
self.get_action_event(
441-
state=ns, e_name=next_name
506+
state=ns,
507+
e_name=action.eventRef.triggerEventRef,
508+
er_name=action.eventRef.resultEventRef,
442509
)
443510
self.state_machine.add_transition(
444511
trigger="",
445512
source=f"{state_name}.{name}",
446513
dest=f"{state_name}.{next_name}",
447514
)
448-
if i == 0:
515+
if i == 0 and not initial_states:
449516
machine_state.initial = name
517+
elif i == 0 and initial_states:
518+
for init_s in initial_states:
519+
self.state_machine.add_transition(
520+
trigger="",
521+
source=f"{state_name}.{init_s}",
522+
dest=f"{state_name}.{name}",
523+
)
450524
elif action_mode == "parallel":
451525
parallel_states.append(name)
452-
if action_mode == "parallel":
526+
if action_mode == "parallel" and not initial_states:
453527
machine_state.initial = parallel_states
528+
elif action_mode == "parallel" and initial_states:
529+
for init_s in initial_states:
530+
for ps in parallel_states:
531+
self.state_machine.add_transition(
532+
trigger="",
533+
source=f"{state_name}.{init_s}",
534+
dest=f"{state_name}.{ps}",
535+
)
454536

455537
def get_action_function(self, state: NestedState, f_name: str):
456538
if self.workflow.functions:
@@ -460,13 +542,14 @@ def get_action_function(self, state: NestedState, f_name: str):
460542
state.metadata = {"function": current_function}
461543
break
462544

463-
def get_action_event(self, state: NestedState, e_name: str):
545+
def get_action_event(self, state: NestedState, e_name: str, er_name: str = ""):
464546
if self.workflow.events:
465547
for event in self.workflow.events:
466548
current_event = event.serialize().__dict__
467549
if current_event["name"] == e_name:
468550
state.metadata = {"event": current_event}
469-
break
551+
if current_event["name"] == er_name:
552+
state.metadata = {"result_event": current_event}
470553

471554
def subflow_state_name(self, action: Action, subflow: Workflow):
472555
return (

0 commit comments

Comments
 (0)