Skip to content

Commit 1d396a0

Browse files
authored
[ML][Pipelines]Fix CLI bug when control flow in pipeline component (Azure#28428)
* add tests * re-record
1 parent 50a34a0 commit 1d396a0

22 files changed

+5459
-2139
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/pipeline_component.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,11 @@ class _AnonymousPipelineComponentSchema(AnonymousAssetSchema, PipelineComponentS
189189
def make(self, data, **kwargs):
190190
from azure.ai.ml.entities._component.pipeline_component import PipelineComponent
191191

192+
# pipeline jobs post process is required before init of pipeline component: it converts control node dict
193+
# to entity.
194+
# however @post_load invocation order is not guaranteed, so we need to call it explicitly here.
195+
_post_load_pipeline_jobs(self.context, data)
196+
192197
return PipelineComponent(
193198
base_path=self.context[BASE_PATH_CONTEXT_KEY],
194199
**data,

sdk/ml/azure-ai-ml/tests/pipeline_job/e2etests/test_control_flow_pipeline.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class TestDoWhile(TestConditionalNodeInPipeline):
124124
def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
125125
params_override = [{"name": randstr('name')}]
126126
pipeline_job = load_job(
127-
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline.yml",
127+
"./tests/test_configs/pipeline_jobs/control_flow/do_while/pipeline.yml",
128128
params_override=params_override,
129129
)
130130
created_pipeline = assert_job_cancel(pipeline_job, client)
@@ -138,7 +138,7 @@ def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[
138138
def test_do_while_pipeline_with_primitive_inputs(self, client: MLClient, randstr: Callable[[], str]) -> None:
139139
params_override = [{"name": randstr('name')}]
140140
pipeline_job = load_job(
141-
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/pipeline_with_primitive_inputs.yml",
141+
"./tests/test_configs/pipeline_jobs/control_flow/do_while/pipeline_with_primitive_inputs.yml",
142142
params_override=params_override,
143143
)
144144
created_pipeline = assert_job_cancel(pipeline_job, client)
@@ -208,3 +208,42 @@ def test_output_binding_foreach_node(self, client: MLClient, randstr: Callable):
208208
'type': 'parallel_for'
209209
}
210210
assert_foreach(client, randstr("job_name"), source, expected_node)
211+
212+
213+
def assert_control_flow_in_pipeline_component(client, component_path, pipeline_path):
214+
params_override = [{"component": component_path}]
215+
pipeline_job = load_job(
216+
pipeline_path,
217+
params_override=params_override,
218+
)
219+
created_pipeline = assert_job_cancel(pipeline_job, client)
220+
pipeline_job_dict = created_pipeline._to_rest_object().as_dict()
221+
222+
pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
223+
assert pipeline_job_dict["properties"]["jobs"] == {}
224+
225+
226+
class TestControlFLowPipelineComponent(TestConditionalNodeInPipeline):
227+
def test_if_else(self, client: MLClient, randstr: Callable[[], str]):
228+
assert_control_flow_in_pipeline_component(
229+
client=client,
230+
component_path="./if_else/simple_pipeline.yml",
231+
pipeline_path="./tests/test_configs/pipeline_jobs/control_flow/control_flow_with_pipeline_component.yml"
232+
)
233+
234+
@pytest.mark.skip(
235+
reason="TODO(2177353): check why recorded tests failure."
236+
)
237+
def test_do_while(self, client: MLClient, randstr: Callable[[], str]):
238+
assert_control_flow_in_pipeline_component(
239+
client=client,
240+
component_path="./do_while/pipeline_component.yml",
241+
pipeline_path="./tests/test_configs/pipeline_jobs/control_flow/control_flow_with_pipeline_component.yml"
242+
)
243+
244+
def test_foreach(self, client: MLClient, randstr: Callable[[], str]):
245+
assert_control_flow_in_pipeline_component(
246+
client=client,
247+
component_path="./parallel_for/simple_pipeline.yml",
248+
pipeline_path="./tests/test_configs/pipeline_jobs/control_flow/control_flow_with_pipeline_component.yml"
249+
)

sdk/ml/azure-ai-ml/tests/pipeline_job/unittests/test_pipeline_job_validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def pipeline_with_compute_binding(compute_name: str):
654654
def test_pipeline_with_invalid_do_while_node(self) -> None:
655655
with pytest.raises(ValidationError) as exception:
656656
load_job(
657-
"./tests/test_configs/dsl_pipeline/pipeline_with_do_while/invalid_pipeline.yml",
657+
"./tests/test_configs/pipeline_jobs/control_flow/do_while/invalid_pipeline.yml",
658658
)
659659
error_message_str = re.findall(r"(\{.*\})", exception.value.args[0].replace("\n", ""))[0]
660660
error_messages = json.loads(error_message_str)

sdk/ml/azure-ai-ml/tests/recordings/component/e2etests/test_component.pyTestComponenttest_create_pipeline_component_from_job.json

Lines changed: 334 additions & 331 deletions
Large diffs are not rendered by default.

sdk/ml/azure-ai-ml/tests/recordings/component/e2etests/test_component.pyTestComponenttest_helloworld_nested_pipeline_component.json

Lines changed: 91 additions & 90 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)