Skip to content

Commit 6397955

Browse files
author
Alan Christie
committed
refactor: Refactored using decoder 2.5.0
1 parent 3883412 commit 6397955

File tree

4 files changed

+89
-25
lines changed

4 files changed

+89
-25
lines changed

tests/job-definitions/job-definitions.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,10 @@ jobs:
136136
splitsmiles:
137137
command: >-
138138
copyf.py {{ inputFile }}
139+
# Simulate multiple output files...
140+
variables:
141+
outputs:
142+
properties:
143+
outputBase:
144+
creates: '{{ outputBase }}_*.smi'
145+
type: files

tests/test_workflow_engine_examples.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,6 @@ def test_workflow_engine_simple_python_molprops_with_options(basic_engine):
398398
assert project_file_exists(output_file_2)
399399

400400

401-
@pytest.mark.skip(reason="WIP")
402401
def test_workflow_engine_simple_python_fanout(basic_engine):
403402
# Arrange
404403
md, da = basic_engine

workflow/decoder.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ def get_steps(definition: dict[str, Any]) -> list[dict[str, Any]]:
6161
return response
6262

6363

64+
def get_step(definition: dict[str, Any], name: str) -> dict[str, Any]:
    """Look up a single step, by name, in a Workflow definition.

    The steps are obtained from get_steps() and searched in order.
    The first step whose "name" value equals the given name is returned.
    An empty dictionary is returned when no step matches.
    """
    # Lazy generator: steps are examined one at a time, in definition order,
    # and the search stops at the first match.
    matching = (
        candidate
        for candidate in get_steps(definition)
        if candidate["name"] == name
    )
    return next(matching, {})
72+
73+
6474
def get_name(definition: dict[str, Any]) -> str:
    """Return the name of a Workflow definition as a string.

    An empty string is returned when the definition has no "name" key.
    Any value that is present is converted with str().
    """
    if "name" not in definition:
        return ""
    return str(definition["name"])
@@ -117,8 +127,8 @@ def get_step_input_variable_names(
117127

118128

119129
def get_step_workflow_variable_mapping(*, step: dict[str, Any]) -> list[Translation]:
120-
"""Returns a list of workflow vaiable name to step variable name tuples
121-
for the given step."""
130+
"""Returns a list of workflow variable name to step variable name
131+
Translation objects for the given step."""
122132
variable_mapping: list[Translation] = []
123133
if "variable-mapping" in step:
124134
for v_map in step["variable-mapping"]:
@@ -134,8 +144,9 @@ def get_step_workflow_variable_mapping(*, step: dict[str, Any]) -> list[Translat
134144
def get_step_prior_step_variable_mapping(
135145
*, step: dict[str, Any]
136146
) -> dict[str, list[Translation]]:
137-
"""Returns list of translate objects, indexed by prior step name,
138-
that identify source step vaiable name to this step's variable name."""
147+
"""Returns list of Translation objects, indexed by prior step name,
148+
that identify source step (output) variable name to this step's (input)
149+
variable name."""
139150
variable_mapping: dict[str, list[Translation]] = {}
140151
if "variable-mapping" in step:
141152
for v_map in step["variable-mapping"]:

workflow/workflow_engine.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
import sys
2727
from typing import Any, Optional
2828

29-
from decoder.decoder import TextEncoding, decode
29+
import decoder.decoder as job_defintion_decoder
30+
from decoder.decoder import TextEncoding
3031
from google.protobuf.message import Message
3132
from informaticsmatters.protobuf.datamanager.pod_message_pb2 import PodMessage
3233
from informaticsmatters.protobuf.datamanager.workflow_message_pb2 import WorkflowMessage
@@ -40,6 +41,7 @@
4041

4142
from .decoder import (
4243
Translation,
44+
get_step,
4345
get_step_prior_step_variable_mapping,
4446
get_step_workflow_variable_mapping,
4547
)
@@ -127,7 +129,7 @@ def _handle_workflow_start_message(self, r_wfid: str) -> None:
127129
# Launch it.
128130
# If there's a launch problem the step (and running workflow) will have
129131
# an error, stopping it. There will be no Pod event as the launch has failed.
130-
self._launch(rwf=rwf_response, step=first_step)
132+
self._launch(wf=wf_response, rwf=rwf_response, step=first_step)
131133

132134
def _handle_workflow_stop_message(self, r_wfid: str) -> None:
133135
"""Logic to handle a STOP message."""
@@ -263,7 +265,7 @@ def _handle_pod_message(self, msg: PodMessage) -> None:
263265
# There's another step!
264266
# For this simple logic it is the next step.
265267
next_step = wf_response["steps"][step_index + 1]
266-
self._launch(rwf=rwf_response, step=next_step)
268+
self._launch(wf=wf_response, rwf=rwf_response, step=next_step)
267269

268270
# Something was started (or there was a launch error and the step
269271
# and running workflow error will have been set).
@@ -278,28 +280,21 @@ def _handle_pod_message(self, msg: PodMessage) -> None:
278280
success=True,
279281
)
280282

281-
def _validate_step_command(
282-
self,
283-
*,
284-
running_workflow_id: str,
285-
step: dict[str, Any],
286-
running_workflow_variables: dict[str, Any],
287-
) -> str | dict[str, Any]:
288-
"""Returns an error message if the command isn't valid.
289-
Without a message we return all the variables that were (successfully)
290-
applied to the command."""
291-
283+
def _get_step_job(self, *, step: dict[str, Any]) -> dict[str, Any]:
284+
"""Gets the Job definition for a given Step."""
292285
# We get the Job from the step specification, which must contain
293286
# the keys "collection", "job", and "version". Here we assume that
294287
# the workflow definition has passed the RUN-level validation
295288
# which means we can get these values.
289+
assert "specification" in step
296290
step_spec: dict[str, Any] = step["specification"]
297291
job_collection: str = step_spec["collection"]
298292
job_job: str = step_spec["job"]
299293
job_version: str = step_spec["version"]
300294
job, _ = self._wapi_adapter.get_job(
301295
collection=job_collection, job=job_job, version=job_version
302296
)
297+
303298
_LOGGER.debug(
304299
"API.get_job(%s, %s, %s) returned: -\n%s",
305300
job_collection,
@@ -308,6 +303,19 @@ def _validate_step_command(
308303
str(job),
309304
)
310305

306+
return job
307+
308+
def _validate_step_command(
309+
self,
310+
*,
311+
running_workflow_id: str,
312+
step: dict[str, Any],
313+
running_workflow_variables: dict[str, Any],
314+
) -> str | dict[str, Any]:
315+
"""Returns an error message if the command isn't valid.
316+
Without a message we return all the variables that were (successfully)
317+
applied to the command."""
318+
311319
# Start with any variables provided in the step's specification.
312320
# This will be our "all variables" map for this step,
313321
# which we will add to (and maybe even over-write)...
@@ -345,12 +353,15 @@ def _validate_step_command(
345353
all_variables[tr.out] = prior_step["variables"][tr.in_]
346354

347355
# Now ... can the command be compiled!?
348-
message, success = decode(
356+
job: dict[str, Any] = self._get_step_job(step=step)
357+
message, success = job_defintion_decoder.decode(
349358
job["command"], all_variables, "command", TextEncoding.JINJA2_3_0
350359
)
351360
return all_variables if success else message
352361

353-
def _launch(self, *, rwf: dict[str, Any], step: dict[str, Any]) -> None:
362+
def _launch(
363+
self, *, wf: dict[str, Any], rwf: dict[str, Any], step: dict[str, Any]
364+
) -> None:
354365
step_name: str = step["name"]
355366
rwf_id: str = rwf["id"]
356367
project_id = rwf["project"]["id"]
@@ -380,17 +391,53 @@ def _launch(self, *, rwf: dict[str, Any], step: dict[str, Any]) -> None:
380391
# A step replication number,
381392
# used only for steps expected to run in parallel (even if just once)
382393
step_replication_number: int = 0
394+
# Do we replicate this step (run it more than once)?
395+
# We do if a variable in this step's mapping block
396+
# refers to an output of a prior step whose type is 'files'.
397+
# If the prior step is a 'splitter' we populate the 'replication_values' array
398+
# with the list of files the prior step generated for its output.
383399
replication_values: list[str] = []
384-
source_is_splitter: bool = False
385400
iter_variable: str | None = None
401+
tr_map: dict[str, list[Translation]] = get_step_prior_step_variable_mapping(
402+
step=step
403+
)
404+
for p_step_name, tr_list in tr_map.items():
405+
# We need to get the Job definition for each step
406+
# and then check whether the (output) variable is of type 'files'...
407+
wf_step: dict[str, Any] = get_step(wf, p_step_name)
408+
assert wf_step
409+
job_definition: dict[str, Any] = self._get_step_job(step=wf_step)
410+
jd_outputs: dict[str, Any] = job_defintion_decoder.get_outputs(
411+
job_definition
412+
)
413+
for tr in tr_list:
414+
if jd_outputs.get(tr.in_, {}).get("type") == "files":
415+
iter_variable = tr.out
416+
# Get the prior running step's output values
417+
response, _ = self._wapi_adapter.get_running_workflow_step_by_name(
418+
name=p_step_name,
419+
running_workflow_id=rwf_id,
420+
)
421+
rwfs_id = response["id"]
422+
assert rwfs_id
423+
result, _ = (
424+
self._wapi_adapter.get_running_workflow_step_output_values_for_output(
425+
running_workflow_step_id=rwfs_id,
426+
output_variable=tr.in_,
427+
)
428+
)
429+
replication_values = result["output"].copy()
430+
break
431+
# Stop if we've got an iteration variable
432+
if iter_variable:
433+
break
386434

387435
num_step_instances: int = max(1, len(replication_values))
388436
for iteration in range(num_step_instances):
389437

390438
# If we are replicating this step then we must replace the step's variable
391439
# with a value expected for this iteration.
392-
if source_is_splitter:
393-
assert iter_variable
440+
if iter_variable:
394441
iter_value: str = replication_values[iteration]
395442
_LOGGER.info(
396443
"Replicating step: %s iteration=%s variable=%s value=%s",
@@ -427,7 +474,7 @@ def _launch(self, *, rwf: dict[str, Any], step: dict[str, Any]) -> None:
427474
step_replication_number=step_replication_number,
428475
)
429476
lr: LaunchResult = self._instance_launcher.launch(launch_parameters=lp)
430-
rwfs_id: str | None = lr.running_workflow_step_id
477+
rwfs_id = lr.running_workflow_step_id
431478
assert rwfs_id
432479

433480
if lr.error_num:

0 commit comments

Comments
 (0)