|
8 | 8 | from azure.ai.ml.dsl import pipeline
|
9 | 9 | from azure.ai.ml.dsl._condition import condition
|
10 | 10 | from azure.ai.ml.dsl._do_while import do_while
|
| 11 | +from azure.ai.ml.dsl._group_decorator import group |
11 | 12 | from azure.ai.ml.dsl._parallel_for import parallel_for
|
12 | 13 | from azure.ai.ml.entities._builders.parallel_for import ParallelFor
|
13 | 14 | from azure.ai.ml.entities._job.pipeline._io import InputOutputBase, PipelineInput
|
@@ -208,6 +209,42 @@ def condition_pipeline():
|
208 | 209 | "result": {"_source": "YAML.COMPONENT", "name": "result", "type": "command"},
|
209 | 210 | }
|
210 | 211 |
|
| 212 | + def test_condition_with_group_input(self): |
| 213 | + hello_world_component_no_paths = load_component( |
| 214 | + source=r"./tests/test_configs/components/helloworld_component_no_paths.yml" |
| 215 | + ) |
| 216 | + |
| 217 | + @group |
| 218 | + class SubGroup: |
| 219 | + num: int |
| 220 | + |
| 221 | + @group |
| 222 | + class ParentGroup: |
| 223 | + input_group: SubGroup |
| 224 | + |
| 225 | + @pipeline( |
| 226 | + compute="cpu-cluster", |
| 227 | + ) |
| 228 | + def condition_pipeline(group_input: ParentGroup): |
| 229 | + node1 = hello_world_component_no_paths(component_in_number=1) |
| 230 | + condition(condition=group_input.input_group.num < 100, true_block=node1) |
| 231 | + |
| 232 | + pipeline_job = condition_pipeline(group_input=ParentGroup(input_group=SubGroup(num=10))) |
| 233 | + omit_fields = [ |
| 234 | + "name", |
| 235 | + "properties.display_name", |
| 236 | + "properties.jobs.*.componentId", |
| 237 | + "properties.settings", |
| 238 | + ] |
| 239 | + dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields) |
| 240 | + assert dsl_pipeline_job_dict["properties"]["jobs"]["expression_component"] == { |
| 241 | + "environment_variables": {"AZURE_ML_CLI_PRIVATE_FEATURES_ENABLED": "true"}, |
| 242 | + "name": "expression_component", |
| 243 | + "type": "command", |
| 244 | + "inputs": {"num": {"job_input_type": "literal", "value": "${{parent.inputs.group_input.input_group.num}}"}}, |
| 245 | + "_source": "YAML.COMPONENT", |
| 246 | + } |
| 247 | + |
211 | 248 |
|
212 | 249 | class TestDoWhilePipelineUT(TestControlFlowPipelineUT):
|
213 | 250 | def test_invalid_do_while_pipeline(self):
|
|
0 commit comments