Skip to content

Commit ea403c4

Browse files
authored
[Pipeline] Fix condition node failed when expression is group input (Azure#30650)
* fix condition node failed when expression is group input * update test case * fix code style * fix failed test case * fix failed test case
1 parent b002b6e commit ea403c4

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_pipeline_expression.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,12 @@ def _handle_pipeline_input(
316316
_postfix = _update_postfix(_postfix, _name, _new_name)
317317
_expression_inputs[_new_name] = ExpressionInput(_new_name, _seen_input.type, _seen_input)
318318
_postfix.append(_name)
319+
320+
param_input = pipeline_inputs
321+
for group_name in _pipeline_input._group_names:
322+
param_input = param_input[group_name].values
319323
_expression_inputs[_name] = ExpressionInput(
320-
_name, pipeline_inputs[_pipeline_input._port_name].type, _pipeline_input
324+
_name, param_input[_pipeline_input._port_name].type, _pipeline_input
321325
)
322326
return _postfix, _expression_inputs
323327

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_controlflow_pipeline.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from azure.ai.ml.dsl import pipeline
99
from azure.ai.ml.dsl._condition import condition
1010
from azure.ai.ml.dsl._do_while import do_while
11+
from azure.ai.ml.dsl._group_decorator import group
1112
from azure.ai.ml.dsl._parallel_for import parallel_for
1213
from azure.ai.ml.entities._builders.parallel_for import ParallelFor
1314
from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, PipelineInput
@@ -208,6 +209,42 @@ def condition_pipeline():
208209
"result": {"_source": "YAML.COMPONENT", "name": "result", "type": "command"},
209210
}
210211

212+
def test_condition_with_group_input(self):
213+
hello_world_component_no_paths = load_component(
214+
source=r"./tests/test_configs/components/helloworld_component_no_paths.yml"
215+
)
216+
217+
@group
218+
class SubGroup:
219+
num: int
220+
221+
@group
222+
class ParentGroup:
223+
input_group: SubGroup
224+
225+
@pipeline(
226+
compute="cpu-cluster",
227+
)
228+
def condition_pipeline(group_input: ParentGroup):
229+
node1 = hello_world_component_no_paths(component_in_number=1)
230+
condition(condition=group_input.input_group.num < 100, true_block=node1)
231+
232+
pipeline_job = condition_pipeline(group_input=ParentGroup(input_group=SubGroup(num=10)))
233+
omit_fields = [
234+
"name",
235+
"properties.display_name",
236+
"properties.jobs.*.componentId",
237+
"properties.settings",
238+
]
239+
dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
240+
assert dsl_pipeline_job_dict["properties"]["jobs"]["expression_component"] == {
241+
"environment_variables": {"AZURE_ML_CLI_PRIVATE_FEATURES_ENABLED": "true"},
242+
"name": "expression_component",
243+
"type": "command",
244+
"inputs": {"num": {"job_input_type": "literal", "value": "${{parent.inputs.group_input.input_group.num}}"}},
245+
"_source": "YAML.COMPONENT",
246+
}
247+
211248

212249
class TestDoWhilePipelineUT(TestControlFlowPipelineUT):
213250
def test_invalid_do_while_pipeline(self):

0 commit comments

Comments
 (0)