Skip to content

Commit 2eaee59

Browse files
authored
[ML][Pipelines] Fix if else block intersection check (Azure#29274)
* fix intersection * fix tests
1 parent 448689d commit 2eaee59

File tree

6 files changed

+1479
-37
lines changed

6 files changed

+1479
-37
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/condition_node.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4-
from marshmallow import fields, post_dump
4+
from marshmallow import fields, post_dump, ValidationError
55

66
from azure.ai.ml._schema import StringTransformedEnum
77
from azure.ai.ml._schema.core.fields import DataBindingStr, NodeBindingStr, UnionField
@@ -24,4 +24,25 @@ def simplify_blocks(self, data, **kwargs): # pylint: disable=unused-argument, n
2424
for block in block_keys:
2525
if isinstance(data.get(block), list) and len(data.get(block)) == 1:
2626
data[block] = data.get(block)[0]
27+
28+
# validate that true_block and false_block are not both empty and do not share any node
29+
def _normalize_blocks(key):
30+
blocks = data.get(key, [])
31+
if blocks:
32+
if not isinstance(blocks, list):
33+
blocks = [blocks]
34+
else:
35+
blocks = []
36+
return blocks
37+
38+
true_block = _normalize_blocks("true_block")
39+
false_block = _normalize_blocks("false_block")
40+
41+
if not true_block and not false_block:
42+
raise ValidationError("True block and false block cannot be empty at the same time.")
43+
44+
intersection = set(true_block).intersection(set(false_block))
45+
if intersection:
46+
raise ValidationError(f"True block and false block cannot contain same nodes: {intersection}")
47+
2748
return data

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/condition_node.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -121,29 +121,4 @@ def _validate_params(self, raise_error=True) -> MutableValidationResult:
121121
message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
122122
)
123123

124-
def _get_intersection(lst1, lst2):
125-
if not lst1:
126-
lst1 = []
127-
if not lst2:
128-
lst2 = []
129-
return list(set(lst1) & set(lst2))
130-
131-
intersection = _get_intersection(self.true_block, self.false_block)
132-
133-
if not self.true_block and not self.false_block:
134-
validation_result.append_error(
135-
yaml_path="true_block",
136-
message="'true_block' and 'false_block' of dsl.condition node cannot both be empty.",
137-
)
138-
elif self.true_block is self.false_block:
139-
validation_result.append_error(
140-
yaml_path="true_block",
141-
message="'true_block' and 'false_block' of dsl.condition node cannot be the same object.",
142-
)
143-
elif intersection:
144-
validation_result.append_error(
145-
yaml_path="true_block",
146-
message="'true_block' and 'false_block' of dsl.condition has intersection.",
147-
)
148-
149124
return validation_result.try_raise(self._get_validation_error_target(), raise_error=raise_error)

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_controlflow_pipeline.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,43 @@ def condition_pipeline():
356356
"type": "if_else",
357357
}
358358

359+
@pytest.mark.skipif(condition=not is_live(), reason="TODO(2177353): check why recorded tests failure.")
360+
def test_if_else_multiple_blocks_subgraph(self, client: MLClient):
361+
hello_world_component_no_paths = load_component(
362+
source="./tests/test_configs/components/helloworld_component_no_paths.yml"
363+
)
364+
basic_component = load_component(
365+
source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
366+
)
367+
368+
@pipeline()
369+
def subgraph():
370+
hello_world_component_no_paths(component_in_number=2)
371+
372+
@pipeline(
373+
compute="cpu-cluster",
374+
)
375+
def condition_pipeline():
376+
result = basic_component()
377+
378+
node1 = hello_world_component_no_paths(component_in_number=1)
379+
380+
node2 = subgraph()
381+
382+
condition(condition=result.outputs.output, true_block=[node1, node2])
383+
384+
pipeline_job = condition_pipeline()
385+
with include_private_preview_nodes_in_pipeline():
386+
pipeline_job = assert_job_cancel(pipeline_job, client)
387+
388+
dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
389+
assert dsl_pipeline_job_dict["properties"]["jobs"]["conditionnode"] == {
390+
"_source": "DSL",
391+
"condition": "${{parent.jobs.result.outputs.output}}",
392+
"true_block": ["${{parent.jobs.node1}}", "${{parent.jobs.node2}}"],
393+
"type": "if_else",
394+
}
395+
359396

360397
@pytest.mark.skipif(
361398
condition=is_live(),

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_controlflow_pipeline.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import pytest
2+
from marshmallow import ValidationError
23

34
from azure.ai.ml import Input, load_component
45
from azure.ai.ml.constants._component import ComponentSource
56
from azure.ai.ml.dsl import pipeline
67
from azure.ai.ml.dsl._condition import condition
78
from azure.ai.ml.dsl._parallel_for import parallel_for
9+
from azure.ai.ml.entities._job.pipeline._io import InputOutputBase
810
from azure.ai.ml.exceptions import ValidationException
11+
from test_utilities.utils import omit_with_wildcard
912

1013
from .._util import _DSL_TIMEOUT_SECOND
1114

@@ -67,10 +70,139 @@ def condition_pipeline():
6770
# true block and false block have an intersection
6871
condition(condition=result.outputs.output, false_block=[node1, node2], true_block=[node1])
6972

70-
with pytest.raises(ValidationException) as e:
73+
with pytest.raises(ValidationError) as e:
7174
pipeline_job = condition_pipeline()
7275
pipeline_job._validate(raise_error=True)
73-
assert "'true_block' and 'false_block' of dsl.condition has intersection" in str(e.value)
76+
assert "True block and false block cannot contain same nodes: {'${{parent.jobs.node1}}'" in str(e.value)
77+
78+
@pipeline(compute="cpu-cluster")
79+
def condition_pipeline():
80+
result = basic_component()
81+
node1 = hello_world_component_no_paths()
82+
node2 = hello_world_component_no_paths()
83+
# true block and false block have no intersection
84+
condition(condition=result.outputs.output, false_block=[node1], true_block=[node2])
85+
86+
# no error should be raised
87+
pipeline_job = condition_pipeline()
88+
pipeline_job._validate(raise_error=True)
89+
90+
def test_create_dsl_condition_illegal_cases(self):
91+
basic_component = load_component(
92+
source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
93+
)
94+
basic_node = basic_component()
95+
96+
with pytest.raises(ValidationException) as e:
97+
node = condition(condition=1, true_block=basic_node)
98+
node._validate(raise_error=True)
99+
100+
assert f"must be an instance of {str}, {bool} or {InputOutputBase}" in str(e.value)
101+
102+
with pytest.raises(ValidationException) as e:
103+
node = condition(condition=basic_node.outputs.output3, true_block=basic_node)
104+
node._validate(raise_error=True)
105+
106+
assert "must have 'is_control' field with value 'True'" in str(e.value)
107+
108+
with pytest.raises(ValidationError) as e:
109+
node = condition(condition="${{parent.jobs.xxx.outputs.output}}")
110+
node._validate(raise_error=True)
111+
112+
assert "True block and false block cannot be empty at the same time." in str(e.value)
113+
114+
with pytest.raises(ValidationError) as e:
115+
node = condition(
116+
condition="${{parent.jobs.xxx.outputs.output}}", true_block=basic_node, false_block=basic_node
117+
)
118+
node._validate(raise_error=True)
119+
120+
assert "True block and false block cannot contain same nodes" in str(e.value)
121+
122+
with pytest.raises(ValidationException) as e:
123+
node = condition(condition="xxx", true_block=basic_node)
124+
node._validate(raise_error=True)
125+
126+
assert "condition has invalid binding expression: xxx" in str(e.value)
127+
128+
def test_condition_node(self):
129+
hello_world_component_no_paths = load_component(
130+
source="./tests/test_configs/components/helloworld_component_no_paths.yml"
131+
)
132+
node1 = hello_world_component_no_paths()
133+
node1.name = "node1"
134+
node2 = hello_world_component_no_paths()
135+
node2.name = "node2"
136+
control_node = condition(
137+
condition="${{parent.jobs.condition_predicate.outputs.output}}", false_block=node1, true_block=node2
138+
)
139+
assert control_node._to_rest_object() == {
140+
"_source": "DSL",
141+
"condition": "${{parent.jobs.condition_predicate.outputs.output}}",
142+
"false_block": "${{parent.jobs.node1}}",
143+
"true_block": "${{parent.jobs.node2}}",
144+
"type": "if_else",
145+
}
146+
147+
# test boolean type condition
148+
control_node = condition(condition=True, false_block=node1, true_block=node2)
149+
assert control_node._to_rest_object() == {
150+
"_source": "DSL",
151+
"condition": True,
152+
"false_block": "${{parent.jobs.node1}}",
153+
"true_block": "${{parent.jobs.node2}}",
154+
"type": "if_else",
155+
}
156+
157+
def test_condition_pipeline(self):
158+
basic_component = load_component(
159+
source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
160+
)
161+
162+
hello_world_component_no_paths = load_component(
163+
source="./tests/test_configs/components/helloworld_component_no_paths.yml"
164+
)
165+
166+
@pipeline(
167+
name="test_mldesigner_component_with_conditional_output",
168+
compute="cpu-cluster",
169+
)
170+
def condition_pipeline():
171+
result = basic_component()
172+
node1 = hello_world_component_no_paths(component_in_number=1)
173+
node2 = hello_world_component_no_paths(component_in_number=2)
174+
condition(condition=result.outputs.output, false_block=node1, true_block=node2)
175+
176+
pipeline_job = condition_pipeline()
177+
omit_fields = [
178+
"name",
179+
"properties.display_name",
180+
"properties.jobs.*.componentId",
181+
"properties.settings",
182+
]
183+
dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
184+
assert dsl_pipeline_job_dict["properties"]["jobs"] == {
185+
"conditionnode": {
186+
"_source": "DSL",
187+
"condition": "${{parent.jobs.result.outputs.output}}",
188+
"false_block": "${{parent.jobs.node1}}",
189+
"true_block": "${{parent.jobs.node2}}",
190+
"type": "if_else",
191+
},
192+
"node1": {
193+
"_source": "YAML.COMPONENT",
194+
"inputs": {"component_in_number": {"job_input_type": "literal", "value": "1"}},
195+
"name": "node1",
196+
"type": "command",
197+
},
198+
"node2": {
199+
"_source": "YAML.COMPONENT",
200+
"inputs": {"component_in_number": {"job_input_type": "literal", "value": "2"}},
201+
"name": "node2",
202+
"type": "command",
203+
},
204+
"result": {"_source": "YAML.COMPONENT", "name": "result", "type": "command"},
205+
}
74206

75207

76208
class TestDoWhilePipelineUT(TestControlFlowPipelineUT):

sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_controlflow_pipeline_job.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,24 @@ class TestIfElseUI(TestControlFlowPipelineJobUT):
2626
[
2727
# None true & None false
2828
(
29-
ValidationException,
29+
ValidationError,
3030
"./tests/test_configs/pipeline_jobs/invalid/if_else/none_true_none_false.yml",
31-
"of dsl.condition node cannot both be empty",
32-
'"path": "jobs.conditionnode.true_block",',
31+
"True block and false block cannot be empty at the same time.",
32+
"",
3333
),
3434
# None true & empty false
3535
(
36-
ValidationException,
36+
ValidationError,
3737
"./tests/test_configs/pipeline_jobs/invalid/if_else/none_true_empty_false.yml",
38-
"of dsl.condition node cannot both be empty",
39-
'"path": "jobs.conditionnode.true_block",',
38+
"True block and false block cannot be empty at the same time.",
39+
"",
4040
),
4141
# true & false intersection
4242
(
43-
ValidationException,
43+
ValidationError,
4444
"./tests/test_configs/pipeline_jobs/invalid/if_else/true_false_intersection.yml",
45-
"of dsl.condition has intersection",
46-
'"path": "jobs.conditionnode.true_block",',
45+
"True block and false block cannot contain same nodes:",
46+
"",
4747
),
4848
# invalid binding
4949
(

0 commit comments

Comments
 (0)