Skip to content

Commit 7011ee5

Browse files
authored
Support e2e run for if_else node in pipeline (Azure#26780)
* add e2e test * update test * fix test * fix pipeline tests * add comment * fix test
1 parent 7549cb1 commit 7011ee5

File tree

12 files changed

+1463
-12
lines changed

12 files changed

+1463
-12
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def PipelineJobsField():
7272
],
7373
}
7474

75+
# Note: the private node types are only available if the private preview flag is enabled before the pipeline job
76+
# schema class is initialized.
7577
if is_private_preview_enabled():
7678
pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
7779
pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/condition_node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def _create_schema_for_validation(
3636

3737
return ConditionNodeSchema(context=context)
3838

39+
@classmethod
def _from_rest_object(cls, obj: dict) -> "ConditionNode":
    """Create a node from its REST dictionary.

    :param obj: Dictionary whose items are passed to the constructor as keyword arguments.
    :type obj: dict
    :return: The instance produced by calling ``cls`` with those keyword arguments.
    :rtype: ConditionNode
    """
    # Every key of the REST dict must therefore be accepted by the constructor.
    return cls(**obj)
42+
3943
def _to_dict(self) -> Dict:
4044
return self._dump_for_validation()
4145

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/control_flow_node.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,6 @@ def _get_validation_error_target(cls) -> ErrorTarget:
6767
"""
6868
return ErrorTarget.PIPELINE
6969

70-
@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
    """Rebuild a control flow node from its REST dictionary, dispatching on the node type.

    :param obj: REST dictionary of the node; its type field selects the loader.
    :type obj: dict
    :param reference_node_list: Passed through unchanged to the selected loader.
    :type reference_node_list: list
    :return: Whatever the loader registered for the node type returns.
    :rtype: ControlFlowNode
    """
    # Function-scope import, kept as in the original (presumably to avoid a circular import - confirm).
    from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

    loader = pipeline_node_factory.get_load_from_rest_object_func(_type=obj.get(CommonYamlFields.TYPE, None))
    return loader(obj, reference_node_list)
77-
7870

7971
class LoopNode(ControlFlowNode, ABC):
8072
"""
@@ -132,3 +124,11 @@ def _get_data_binding_expression_value(expression, regex):
132124
@staticmethod
def _is_loop_node_dict(obj):
    """Tell whether the type field of ``obj`` names a loop node.

    :param obj: Node dictionary; a missing type field is treated as ``None``.
    :return: True only when the type is ``ControlFlowType.DO_WHILE``.
    """
    loop_node_types = [ControlFlowType.DO_WHILE]
    node_type = obj.get(CommonYamlFields.TYPE, None)
    return node_type in loop_node_types
127+
128+
@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
    """Rebuild a control flow node from its REST dictionary, dispatching on the node type.

    :param obj: REST dictionary of the node; its type field selects the loader.
    :type obj: dict
    :param reference_node_list: Passed through unchanged to the selected loader.
    :type reference_node_list: list
    :return: Whatever the loader registered for the node type returns.
    :rtype: ControlFlowNode
    """
    # Function-scope import, kept as in the original (presumably to avoid a circular import - confirm).
    from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

    loader = pipeline_node_factory.get_load_from_rest_object_func(_type=obj.get(CommonYamlFields.TYPE, None))
    return loader(obj, reference_node_list)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_load_component.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from azure.ai.ml.dsl._component_func import to_component_func
1818
from azure.ai.ml.dsl._overrides_definition import OverrideDefinition
1919
from azure.ai.ml.entities._builders import BaseNode, Command, Import, Parallel, Spark, Sweep
20+
from azure.ai.ml.entities._builders.condition_node import ConditionNode
2021
from azure.ai.ml.entities._builders.do_while import DoWhile
2122
from azure.ai.ml.entities._builders.pipeline import Pipeline
2223
from azure.ai.ml.entities._component.component import Component
@@ -81,6 +82,12 @@ def __init__(self):
8182
load_from_rest_object_func=DoWhile._from_rest_object,
8283
nested_schema=None,
8384
)
85+
self.register_type(
86+
_type=ControlFlowType.IF_ELSE,
87+
create_instance_func=None,
88+
load_from_rest_object_func=ConditionNode._from_rest_object,
89+
nested_schema=None,
90+
)
8491

8592
@classmethod
8693
def _get_func(cls, _type: str, funcs):

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/pipeline_job.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from azure.ai.ml.constants._component import ComponentSource
2929
from azure.ai.ml.constants._job.pipeline import ValidationErrorCode
3030
from azure.ai.ml.entities._builders import BaseNode
31+
from azure.ai.ml.entities._builders.condition_node import ConditionNode
3132
from azure.ai.ml.entities._builders.control_flow_node import LoopNode
3233
from azure.ai.ml.entities._builders.import_node import Import
3334
from azure.ai.ml.entities._builders.parallel import Parallel
@@ -266,11 +267,13 @@ def _customized_validate(self) -> MutableValidationResult:
266267

267268
def _validate_input(self):
268269
validation_result = self._create_empty_validation_result()
270+
# TODO(1979547): refine this logic: not all nodes have `_get_input_binding_dict` method
269271
used_pipeline_inputs = set(
270272
itertools.chain(
271273
*[
272274
self.component._get_input_binding_dict(node if not isinstance(node, LoopNode) else node.body)[0]
273-
for node in self.jobs.values()
275+
for node in self.jobs.values() if not isinstance(node, ConditionNode)
276+
# condition node has no inputs
274277
]
275278
)
276279
)

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from .._utils._experimental import experimental
3939
from .._utils.utils import is_data_binding_expression
40+
from ..entities._builders.condition_node import ConditionNode
4041
from ..entities._component.automl_component import AutoMLComponent
4142
from ..entities._component.pipeline_component import PipelineComponent
4243
from ._code_operations import CodeOperations
@@ -531,6 +532,8 @@ def resolve_base_node(name, node: BaseNode):
531532
self._job_operations._resolve_arm_id_for_automl_job(job_instance, resolver, inside_pipeline=True)
532533
elif isinstance(job_instance, BaseNode):
533534
resolve_base_node(key, job_instance)
535+
elif isinstance(job_instance, ConditionNode):
536+
pass
534537
else:
535538
msg = f"Non supported job type in Pipeline: {type(job_instance)}"
536539
raise ComponentException(

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898

9999
from .._utils._experimental import experimental
100100
from ..constants._component import ComponentSource
101+
from ..entities._builders.condition_node import ConditionNode
101102
from ..entities._job.pipeline._io import InputOutputBase, _GroupAttrDict, PipelineInput
102103
from ._component_operations import ComponentOperations
103104
from ._compute_operations import ComputeOperations
@@ -419,7 +420,8 @@ def _validate(
419420

420421
for node_name, node in job.jobs.items():
421422
try:
422-
if not isinstance(node, DoWhile):
423+
# TODO(1979547): refactor, not all nodes have compute
424+
if not isinstance(node, (DoWhile, ConditionNode)):
423425
node.compute = self._try_get_compute_arm_id(node.compute)
424426
except Exception as e: # pylint: disable=broad-except
425427
validation_result.append_error(yaml_path=f"jobs.{node_name}.compute", message=str(e))

sdk/ml/azure-ai-ml/tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
from azure.ai.ml import MLClient, load_component, load_job
1414
from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery
1515
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
16-
from azure.ai.ml._utils._asset_utils import get_object_hash
1716
from azure.ai.ml._utils.utils import hash_dict
18-
from azure.ai.ml.constants._common import GitProperties
1917
from azure.ai.ml.entities import AzureBlobDatastore, Component
2018
from azure.ai.ml.entities._assets import Data, Model
2119
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
@@ -526,6 +524,7 @@ def credentialless_datastore(client: MLClient, storage_account_name: str) -> Azu
526524
def enable_pipeline_private_preview_features(mocker: MockFixture):
    """Patch ``is_private_preview_enabled`` to return True in each pipeline module listed below.

    :param mocker: Object whose ``patch`` method is called once per target with ``return_value=True``.
    """
    # Each module is patched under its own dotted path; the order matches the original calls.
    patch_targets = (
        "azure.ai.ml.entities._job.pipeline.pipeline_job.is_private_preview_enabled",
        "azure.ai.ml.dsl._pipeline_component_builder.is_private_preview_enabled",
        "azure.ai.ml._schema.pipeline.pipeline_component.is_private_preview_enabled",
    )
    for target in patch_targets:
        mocker.patch(target, return_value=True)
529528

530529

531530
@pytest.fixture()
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import contextlib
2+
import pytest
3+
4+
from azure.ai.ml._schema.pipeline import PipelineJobSchema
5+
from .._util import _DSL_TIMEOUT_SECOND
6+
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, omit_with_wildcard
7+
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineJobsField
8+
from devtools_testutils import AzureRecordedTestCase
9+
10+
from azure.ai.ml import MLClient, load_component
11+
from azure.ai.ml.dsl import pipeline
12+
from azure.ai.ml.dsl._condition import condition
13+
14+
15+
@contextlib.contextmanager
def include_private_preview_nodes_in_pipeline():
    """Temporarily replace the ``jobs`` field of ``PipelineJobSchema`` with a fresh ``PipelineJobsField()``.

    The previous field is put back when the ``with`` block exits, including when it exits with an exception.
    """
    declared_fields = PipelineJobSchema._declared_fields
    saved_jobs_field = declared_fields["jobs"]
    # The replacement field is built on entry, i.e. at the time the context manager is used.
    declared_fields["jobs"] = PipelineJobsField()

    try:
        yield
    finally:
        declared_fields["jobs"] = saved_jobs_field
24+
25+
26+
@pytest.mark.usefixtures(
    "enable_environment_id_arm_expansion",
    "enable_pipeline_private_preview_features",
    "mock_code_hash",
    "mock_component_hash",
    "recorded_test",
)
@pytest.mark.timeout(timeout=_DSL_TIMEOUT_SECOND, method=_PYTEST_TIMEOUT_METHOD)
@pytest.mark.e2etest
class TestDynamicPipeline(AzureRecordedTestCase):
    """End-to-end tests for DSL pipelines that contain an ``if_else`` condition node."""

    def test_dsl_condition_pipeline(self, client: MLClient):
        """Submit a pipeline whose condition is bound to a node output and check the returned REST jobs."""
        hello_world_component_no_paths = load_component(
            source="./tests/test_configs/components/helloworld_component_no_paths.yml"
        )
        basic_component = load_component(
            source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
        )

        @pipeline(
            name="test_mldesigner_component_with_conditional_output",
            compute="cpu-cluster",
        )
        def condition_pipeline():
            result = basic_component(str_param="abc", int_param=1)

            node1 = hello_world_component_no_paths(component_in_number=1)
            node2 = hello_world_component_no_paths(component_in_number=2)
            condition(condition=result.outputs.output, false_block=node1, true_block=node2)

        pipeline_job = condition_pipeline()

        # Update the jobs field to include private preview nodes while the job is submitted.
        with include_private_preview_nodes_in_pipeline():
            pipeline_job = client.jobs.create_or_update(pipeline_job)

        # Fields left out of the comparison below.
        omit_fields = [
            "name",
            "properties.display_name",
            "properties.jobs.*.componentId",
            "properties.settings",
        ]
        dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
        assert dsl_pipeline_job_dict["properties"]["jobs"] == {
            "conditionnode": {
                "condition": "${{parent.jobs.result.outputs.output}}",
                "false_block": "${{parent.jobs.node1}}",
                "true_block": "${{parent.jobs.node2}}",
                "type": "if_else",
            },
            "node1": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {"component_in_number": {"job_input_type": "literal", "value": "1"}},
                "limits": None,
                "name": "node1",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
            "node2": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {"component_in_number": {"job_input_type": "literal", "value": "2"}},
                "limits": None,
                "name": "node2",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
            "result": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {
                    "int_param": {"job_input_type": "literal", "value": "1"},
                    "str_param": {"job_input_type": "literal", "value": "abc"},
                },
                "limits": None,
                "name": "result",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
        }

    @pytest.mark.skip(reason="TODO(2027778): Verify after primitive condition is supported.")
    def test_dsl_condition_pipeline_with_primitive_input(self, client: MLClient):
        """Submit a pipeline whose condition is the literal ``True``; skipped until that is supported."""
        hello_world_component_no_paths = load_component(
            source="./tests/test_configs/components/helloworld_component_no_paths.yml"
        )

        @pipeline(
            name="test_mldesigner_component_with_conditional_output",
            compute="cpu-cluster",
        )
        def condition_pipeline():
            node1 = hello_world_component_no_paths(component_in_number=1)
            node2 = hello_world_component_no_paths(component_in_number=2)
            condition(condition=True, false_block=node1, true_block=node2)

        pipeline_job = condition_pipeline()

        # This pipeline also contains an if_else node, so it is submitted with the same
        # private preview jobs field as test_dsl_condition_pipeline.
        with include_private_preview_nodes_in_pipeline():
            client.jobs.create_or_update(pipeline_job)

0 commit comments

Comments
 (0)