Skip to content

Commit 87e9799

Browse files
authored
Raise useful error on workflow.connect_with for wrong labels (#2110)
* Typehints * Fix typo in error * Raise connection error on connect_with for wrong names * Add Workflow.safe_connect_with * Add test for Workflow.safe_connect_with * Switch to adding parameter 'permissive' to Workflow.connect_with()
1 parent 8b721c0 commit 87e9799

File tree

2 files changed

+79
-17
lines changed

2 files changed

+79
-17
lines changed

src/ansys/dpf/core/workflow.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
"""Workflow."""
2424

25+
from __future__ import annotations
26+
2527
from enum import Enum
2628
import logging
2729
import os
@@ -714,7 +716,12 @@ def output_names(self):
714716
return out
715717

716718
@version_requires("3.0")
717-
def connect_with(self, left_workflow, output_input_names=None):
719+
def connect_with(
720+
self,
721+
left_workflow: Workflow,
722+
output_input_names: Union[tuple[str, str], dict[str, str]] = None,
723+
permissive: bool = True,
724+
):
718725
"""Prepend a given workflow to the current workflow.
719726
720727
Updates the current workflow to include all the operators of the workflow given as argument.
@@ -724,15 +731,18 @@ def connect_with(self, left_workflow, output_input_names=None):
724731
725732
Parameters
726733
----------
727-
left_workflow : core.Workflow
734+
left_workflow:
728735
The given workflow's outputs are chained with the current workflow's inputs.
729-
output_input_names : str tuple, str dict optional
736+
output_input_names:
730737
Map used to connect the outputs of the given workflow to the inputs of the current
731738
workflow.
732739
Check the names of available inputs and outputs for each workflow using
733740
`Workflow.input_names` and `Workflow.output_names`.
734741
The default is ``None``, in which case it tries to connect each output of the
735742
left_workflow with an input of the current workflow with the same name.
743+
permissive:
744+
Whether to filter 'output_input_names' to only keep available connections.
745+
Otherwise raise an error if 'output_input_names' contains unavailable inputs or outputs.
736746
737747
Examples
738748
--------
@@ -791,24 +801,40 @@ def connect_with(self, left_workflow, output_input_names=None):
791801
792802
"""
793803
if output_input_names:
794-
core_api = self._server.get_api_for_type(
795-
capi=data_processing_capi.DataProcessingCAPI,
796-
grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI,
797-
)
798-
map = object_handler.ObjHandler(
799-
data_processing_api=core_api,
800-
internal_obj=self._api.workflow_create_connection_map_for_object(self),
801-
)
802804
if isinstance(output_input_names, tuple):
803-
self._api.workflow_add_entry_connection_map(
804-
map, output_input_names[0], output_input_names[1]
805+
output_input_names = {output_input_names[0]: output_input_names[1]}
806+
if isinstance(output_input_names, dict):
807+
core_api = self._server.get_api_for_type(
808+
capi=data_processing_capi.DataProcessingCAPI,
809+
grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI,
810+
)
811+
map = object_handler.ObjHandler(
812+
data_processing_api=core_api,
813+
internal_obj=self._api.workflow_create_connection_map_for_object(self),
805814
)
806-
elif isinstance(output_input_names, dict):
807-
for key in output_input_names:
808-
self._api.workflow_add_entry_connection_map(map, key, output_input_names[key])
815+
output_names = left_workflow.output_names
816+
input_names = self.input_names
817+
if permissive:
818+
output_input_names = dict(
819+
filter(
820+
lambda item: item[0] in left_workflow.output_names
821+
and item[1] in self.input_names,
822+
output_input_names.items(),
823+
)
824+
)
825+
for output_name, input_name in output_input_names.items():
826+
if output_name not in output_names:
827+
raise ValueError(
828+
f"Cannot connect workflow output '{output_name}'. Exposed outputs are:\n{output_names}"
829+
)
830+
elif input_name not in input_names:
831+
raise ValueError(
832+
f"Cannot connect workflow input '{input_name}'. Exposed inputs are:\n{input_names}"
833+
)
834+
self._api.workflow_add_entry_connection_map(map, output_name, input_name)
809835
else:
810836
raise TypeError(
811-
"output_input_names argument is expect" "to be either a str tuple or a str dict"
837+
"output_input_names argument is expected to be either a str tuple or a str dict"
812838
)
813839
self._api.work_flow_connect_with_specified_names(self, left_workflow, map)
814840
else:

tests/test_workflow.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,42 @@ def test_connect_with_dict_workflow(cyclic_lin_rst, cyclic_ds, server_type):
695695
fc = wf2.get_output("u", dpf.core.types.fields_container)
696696

697697

698+
def test_workflow_connect_raise_wrong_label(server_type):
699+
workflow1 = dpf.core.Workflow()
700+
forward_1 = dpf.core.operators.utility.forward()
701+
workflow1.set_output_name("output", forward_1.outputs.any)
702+
703+
workflow2 = dpf.core.Workflow()
704+
forward_2 = dpf.core.operators.utility.forward()
705+
workflow2.set_input_name("input", forward_2.inputs.any)
706+
707+
with pytest.raises(
708+
ValueError, match="Cannot connect workflow output 'out'. Exposed outputs are:\n"
709+
):
710+
workflow2.connect_with(workflow1, output_input_names={"out": "input"}, permissive=False)
711+
with pytest.raises(
712+
ValueError, match="Cannot connect workflow input 'in'. Exposed inputs are:\n"
713+
):
714+
workflow2.connect_with(workflow1, output_input_names={"output": "in"}, permissive=False)
715+
workflow2.connect_with(workflow1, output_input_names={"output": "input"}, permissive=False)
716+
717+
718+
def test_workflow_connect_with_permissive(server_type):
719+
workflow1 = dpf.core.Workflow()
720+
forward_1 = dpf.core.operators.utility.forward()
721+
workflow1.set_output_name("output", forward_1.outputs.any)
722+
723+
workflow2 = dpf.core.Workflow()
724+
forward_2 = dpf.core.operators.utility.forward()
725+
workflow2.set_input_name("input", forward_2.inputs.any)
726+
727+
workflow2.connect_with(workflow1, output_input_names={"out": "input"})
728+
729+
workflow2.connect_with(workflow1, output_input_names={"output": "in"})
730+
731+
workflow2.connect_with(workflow1, output_input_names=("output", "input"))
732+
733+
698734
@pytest.mark.xfail(raises=dpf.core.errors.ServerTypeError)
699735
def test_info_workflow(allkindofcomplexity, server_type):
700736
model = dpf.core.Model(allkindofcomplexity, server=server_type)

0 commit comments

Comments
 (0)