Skip to content

Commit 369dfd2

Browse files
authored
[ML][Pipeline] Support parameter group on pipeline node (Azure#26713)
* Skip check for node input
  Signed-off-by: Brynn Yin <[email protected]>
* Support parameter group on pipeline component
  Signed-off-by: Brynn Yin <[email protected]>
* Reformat
  Signed-off-by: Brynn Yin <[email protected]>
* Fix lint
  Signed-off-by: Brynn Yin <[email protected]>
  Signed-off-by: Brynn Yin <[email protected]>
1 parent 8d20078 commit 369dfd2

File tree

9 files changed

+985
-39
lines changed

9 files changed

+985
-39
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,14 @@ def _map_type(_meta):
277277
return output_dict
278278

279279
def _get_group_parameter_defaults(self):
280-
return {key: copy.deepcopy(val.default) for key, val in self.inputs.items() if isinstance(val, GroupInput)}
280+
group_defaults = {}
281+
for key, val in self.inputs.items():
282+
if not isinstance(val, GroupInput):
283+
continue
284+
# Copy and insert top-level parameter name into group names for all items
285+
group_defaults[key] = copy.deepcopy(val.default)
286+
group_defaults[key].insert_group_name_for_items(key)
287+
return group_defaults
281288

282289
def _update_nodes_variable_names(self, func_variables: dict):
283290
"""Update nodes list to ordered dict with variable name key and

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from azure.ai.ml.entities import Data, PipelineJob, PipelineJobSettings
1616
from azure.ai.ml.entities._builders.pipeline import Pipeline
1717
from azure.ai.ml.entities._inputs_outputs import Input, is_parameter_group
18+
from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput, _GroupAttrDict
19+
from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
1820
from azure.ai.ml.exceptions import (
1921
MissingPositionalArgsError,
2022
MultipleValueError,
@@ -23,8 +25,6 @@
2325
UnsupportedParameterKindError,
2426
UserErrorException,
2527
)
26-
from azure.ai.ml.entities._job.pipeline._io import NodeOutput, PipelineInput
27-
from azure.ai.ml.entities._job.pipeline._pipeline_expression import PipelineExpression
2828

2929
from ._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func
3030
from ._settings import _dsl_settings_stack
@@ -42,6 +42,8 @@
4242
bool,
4343
int,
4444
float,
45+
PipelineExpression,
46+
_GroupAttrDict,
4547
)
4648
module_logger = logging.getLogger(__name__)
4749

@@ -242,7 +244,7 @@ def _validate_args(func, args, kwargs):
242244

243245
def _is_supported_data_type(_data):
244246
return (
245-
isinstance(_data, SUPPORTED_INPUT_TYPES + (PipelineExpression,))
247+
isinstance(_data, SUPPORTED_INPUT_TYPES)
246248
or is_parameter_group(_data)
247249
)
248250

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/base_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution
2424
from azure.ai.ml.entities._mixins import YamlTranslatableMixin
2525
from azure.ai.ml.entities._util import convert_ordered_dict_to_dict, resolve_pipeline_parameters
26-
from azure.ai.ml.entities._validation import SchemaValidatableMixin, MutableValidationResult
26+
from azure.ai.ml.entities._validation import MutableValidationResult, SchemaValidatableMixin
2727
from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
2828

2929
module_logger = logging.getLogger(__name__)
@@ -231,7 +231,10 @@ def _parse_io(cls, io_dict: dict, parse_cls):
231231
value = value._deepcopy() # Decoupled input and output
232232
io_dict[key] = value
233233
value.mode = None
234-
elif isinstance(value, dict):
234+
elif type(value) == dict: # pylint: disable=unidiomatic-typecheck
235+
# Use an exact type comparison instead of isinstance to skip _GroupAttrDict
236+
# When loading from YAML, io will be a plain dict,
237+
# like {'job_data_path': '${{parent.inputs.pipeline_job_data_path}}'}
235238
# parse dict to allowed type
236239
io_dict[key] = parse_cls(**value)
237240

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/pipeline_component.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,13 @@ def _validate_binding_inputs(self, node: BaseNode) -> MutableValidationResult:
232232
not set. Raise error if pipeline input is optional but link to
233233
required inputs.
234234
"""
235-
component_definition_inputs = node.component.inputs
235+
component_definition_inputs = {}
236+
# Add flattened group input into definition inputs.
237+
# e.g. Add {'group_name.item': PipelineInput} for {'group_name': GroupInput}
238+
for name, val in node.component.inputs.items():
239+
if isinstance(val, GroupInput):
240+
component_definition_inputs.update(val.flatten(group_parameter_name=name))
241+
component_definition_inputs[name] = val
236242
# Collect binding relation dict {'pipeline_input': ['node_input']}
237243
validation_result = self._create_empty_validation_result()
238244
binding_dict, optional_binding_in_expression_dict = self._get_input_binding_dict(node)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/group_input.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4-
4+
import copy
55
from enum import Enum as PyEnum
66

77
from azure.ai.ml.constants._component import IOConstants
@@ -38,9 +38,15 @@ def _create_default(self):
3838
from .._job.pipeline._io import PipelineInput
3939

4040
default_dict = {}
41+
# Note: no top-level group names at this time.
4142
for k, v in self.values.items():
42-
# Assign directly if is subgroup, else create PipelineInput object
43-
default_dict[k] = v.default if isinstance(v, GroupInput) else PipelineInput(name=k, data=v.default, meta=v)
43+
# Create PipelineInput object if not subgroup
44+
if not isinstance(v, GroupInput):
45+
default_dict[k] = PipelineInput(name=k, data=v.default, meta=v)
46+
continue
47+
# Copy and insert k into group names for subgroup
48+
default_dict[k] = copy.deepcopy(v.default)
49+
default_dict[k].insert_group_name_for_items(k)
4450
return self._create_group_attr_dict(default_dict)
4551

4652
@classmethod

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/_input_output_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobInput as RestJobInput
1313
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobInputType
1414
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobOutput as RestJobOutput
15-
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobOutputType
16-
from azure.ai.ml._restclient.v2022_10_01_preview.models import LiteralJobInput
15+
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobOutputType, LiteralJobInput
1716
from azure.ai.ml._restclient.v2022_10_01_preview.models import MLFlowModelJobInput as RestMLFlowModelJobInput
1817
from azure.ai.ml._restclient.v2022_10_01_preview.models import MLFlowModelJobOutput as RestMLFlowModelJobOutput
1918
from azure.ai.ml._restclient.v2022_10_01_preview.models import MLTableJobInput as RestMLTableJobInput
@@ -191,7 +190,8 @@ def to_rest_dataset_literal_inputs(
191190
for input_name, input_value in inputs.items():
192191
if job_type == JobType.PIPELINE:
193192
validate_pipeline_input_key_contains_allowed_characters(input_name)
194-
else:
193+
elif job_type:
194+
# We pass job_type=None for pipeline nodes, and want to skip this check for nodes.
195195
validate_key_contains_allowed_characters(input_name)
196196
if isinstance(input_value, Input):
197197
if input_value.path and isinstance(input_value.path, str) and is_data_binding_expression(input_value.path):

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_io.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,6 @@ def __init__(self, name: str, meta: Input, group_names: List[str] = None, **kwar
438438
"""
439439
super(PipelineInput, self).__init__(name=name, meta=meta, **kwargs)
440440
self._group_names = group_names if group_names else []
441-
self._full_name = "%s.%s" % (".".join(self._group_names), self._name) if self._group_names else self._name
442441

443442
def __str__(self) -> str:
444443
return self._data_binding()
@@ -451,7 +450,7 @@ def _build_data(self, data, key=None): # pylint: disable=unused-argument
451450
msg = "Can not bind input to another component's input."
452451
raise ValidationException(message=msg, no_personal_data_message=msg, target=ErrorTarget.PIPELINE)
453452
if isinstance(data, (PipelineInput, NodeOutput)):
454-
# If value is input or output, it's a data binding, we require it have a owner so we can convert it to
453+
# If value is an input or output, it's a data binding; an owner is required to convert it to
455454
# a data binding, eg: ${{parent.inputs.xxx}}
456455
if isinstance(data, NodeOutput) and data._owner is None:
457456
msg = "Setting input binding {} to output without owner is not allowed."
@@ -468,7 +467,8 @@ def _build_data(self, data, key=None): # pylint: disable=unused-argument
468467
return data
469468

470469
def _data_binding(self):
471-
return f"${{{{parent.inputs.{self._full_name}}}}}"
470+
full_name = "%s.%s" % (".".join(self._group_names), self._name) if self._group_names else self._name
471+
return f"${{{{parent.inputs.{full_name}}}}}"
472472

473473
def _to_input(self) -> Input:
474474
"""Convert pipeline input to component input for pipeline component."""
@@ -628,6 +628,15 @@ def flatten(self, group_parameter_name):
628628
)
629629
return flattened_parameters
630630

631+
def insert_group_name_for_items(self, group_name):
632+
# Insert one group name for all items.
633+
for v in self.values():
634+
if isinstance(v, _GroupAttrDict):
635+
v.insert_group_name_for_items(group_name)
636+
elif isinstance(v, PipelineInput):
637+
# Insert group names for pipeline input
638+
v._group_names = [group_name] + v._group_names
639+
631640

632641
class OutputsAttrDict(dict):
633642
def __init__(self, outputs: dict, **kwargs):

sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
import pydash
99
import pytest
10-
from pipeline_job.e2etests.test_pipeline_job import assert_job_input_output_types
11-
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, omit_with_wildcard
12-
1310
from azure.ai.ml import (
1411
Input,
1512
MLClient,
@@ -25,18 +22,21 @@
2522
from azure.ai.ml.constants._common import AssetTypes, InputOutputModes
2623
from azure.ai.ml.constants._job.pipeline import PipelineConstants
2724
from azure.ai.ml.dsl._load_import import to_component
25+
from azure.ai.ml.dsl._parameter_group_decorator import parameter_group
2826
from azure.ai.ml.entities import CommandComponent, CommandJob
2927
from azure.ai.ml.entities import Component
3028
from azure.ai.ml.entities import Component as ComponentEntity
3129
from azure.ai.ml.entities import Data, PipelineJob
3230
from azure.ai.ml.exceptions import ValidationException
3331
from azure.ai.ml.parallel import ParallelJob, RunFunction, parallel_run_function
32+
from azure.core.exceptions import HttpResponseError
3433
from azure.core.polling import LROPoller
34+
from devtools_testutils import AzureRecordedTestCase
35+
from pipeline_job.e2etests.test_pipeline_job import assert_job_input_output_types
36+
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, omit_with_wildcard
3537

3638
from .._util import _DSL_TIMEOUT_SECOND
3739

38-
from devtools_testutils import AzureRecordedTestCase
39-
4040
tests_root_dir = Path(__file__).parent.parent.parent
4141
components_dir = tests_root_dir / "test_configs/components/"
4242
job_input = Input(
@@ -57,6 +57,17 @@
5757
]
5858

5959

60+
def assert_job_cancel(pipeline, client: MLClient):
61+
job = client.jobs.create_or_update(pipeline)
62+
try:
63+
cancel_poller = client.jobs.begin_cancel(job.name)
64+
assert isinstance(cancel_poller, LROPoller)
65+
assert cancel_poller.result() is None
66+
except HttpResponseError:
67+
pass
68+
return job
69+
70+
6071
@pytest.mark.usefixtures(
6172
"enable_environment_id_arm_expansion",
6273
"enable_pipeline_private_preview_features",
@@ -1058,7 +1069,6 @@ def valid_pipeline_func(
10581069
required_input: Input,
10591070
required_param: str,
10601071
node_compute: str = "cpu-cluster",
1061-
# node_compute: str = 'azureml:cpu-cluster', # both will be supported
10621072
):
10631073
default_optional_func(
10641074
required_input=required_input,
@@ -1090,6 +1100,64 @@ def valid_pipeline_func(
10901100
in caplog.messages
10911101
)
10921102

1103+
def test_create_pipeline_with_parameter_group(self, client: MLClient) -> None:
1104+
default_optional_func = load_component(source=str(components_dir / "default_optional_component.yml"))
1105+
1106+
@parameter_group
1107+
class SubGroup:
1108+
required_param: str
1109+
1110+
@parameter_group
1111+
class Group:
1112+
sub: SubGroup
1113+
node_compute: str = "cpu-cluster"
1114+
1115+
@dsl.pipeline()
1116+
def sub_pipeline_func(
1117+
required_input: Input,
1118+
group: Group,
1119+
sub_group: SubGroup,
1120+
):
1121+
default_optional_func(
1122+
required_input=required_input,
1123+
required_param=group.sub.required_param,
1124+
)
1125+
node2 = default_optional_func(
1126+
required_input=required_input,
1127+
required_param=sub_group.required_param,
1128+
)
1129+
node2.compute = group.node_compute
1130+
1131+
@dsl.pipeline(default_compute="cpu-cluster")
1132+
def root_pipeline_with_group(
1133+
r_required_input: Input,
1134+
r_group: Group,
1135+
):
1136+
sub_pipeline_func(required_input=r_required_input, group=r_group, sub_group=r_group.sub)
1137+
1138+
job = root_pipeline_with_group(
1139+
r_required_input=Input(type="uri_file", path="https://dprepdata.blob.core.windows.net/demo/Titanic.csv"),
1140+
r_group=Group(sub=SubGroup(required_param="hello")),
1141+
)
1142+
rest_job = assert_job_cancel(job, client)
1143+
assert len(rest_job.inputs) == 2
1144+
rest_job_dict = rest_job._to_dict()
1145+
assert rest_job_dict["inputs"] == {
1146+
"r_required_input": {
1147+
"mode": "ro_mount",
1148+
"type": "uri_file",
1149+
"path": "azureml:https://dprepdata.blob.core.windows.net/demo/Titanic.csv",
1150+
},
1151+
"r_group.sub.required_param": "hello",
1152+
"r_group.node_compute": "cpu-cluster",
1153+
}
1154+
assert rest_job_dict["jobs"]["sub_pipeline_func"]["inputs"] == {
1155+
"required_input": {"path": "${{parent.inputs.r_required_input}}"},
1156+
"group.sub.required_param": {"path": "${{parent.inputs.r_group.sub.required_param}}"},
1157+
"group.node_compute": {"path": "${{parent.inputs.r_group.node_compute}}"},
1158+
"sub_group.required_param": {"path": "${{parent.inputs.r_group.sub.required_param}}"},
1159+
}
1160+
10931161
def test_pipeline_with_none_parameter_has_default_optional_true(self, client: MLClient) -> None:
10941162
default_optional_func = load_component(source=str(components_dir / "default_optional_component.yml"))
10951163

@@ -1525,9 +1593,7 @@ def parallel_in_pipeline(job_data_path, score_model):
15251593
assert_job_input_output_types(pipeline_job)
15261594
assert pipeline_job.settings.default_compute == "cpu-cluster"
15271595

1528-
@pytest.mark.skip(
1529-
"https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659"
1530-
)
1596+
@pytest.mark.skip("https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659")
15311597
def test_parallel_components_with_file_input(self, client: MLClient) -> None:
15321598
components_dir = tests_root_dir / "test_configs/dsl_pipeline/parallel_component_with_file_input"
15331599

@@ -1958,11 +2024,11 @@ def pipeline(job_in_number, job_in_other_number, job_in_path):
19582024
client.jobs.get(child.name)
19592025
client.jobs.get(child.name)._repr_html_()
19602026

1961-
@pytest.mark.skip(
1962-
"https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659"
1963-
)
2027+
@pytest.mark.skip("https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659")
19642028
def test_dsl_pipeline_without_setting_binding_node(self, client: MLClient) -> None:
1965-
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import pipeline_without_setting_binding_node
2029+
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import (
2030+
pipeline_without_setting_binding_node,
2031+
)
19662032

19672033
pipeline = pipeline_without_setting_binding_node()
19682034
pipeline_job = client.jobs.create_or_update(pipeline)
@@ -2011,9 +2077,7 @@ def test_dsl_pipeline_without_setting_binding_node(self, client: MLClient) -> No
20112077
}
20122078
assert expected_job == actual_job
20132079

2014-
@pytest.mark.skip(
2015-
"https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659"
2016-
)
2080+
@pytest.mark.skip("https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659")
20172081
def test_dsl_pipeline_with_only_setting_pipeline_level(self, client: MLClient) -> None:
20182082
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import (
20192083
pipeline_with_only_setting_pipeline_level,
@@ -2066,12 +2130,12 @@ def test_dsl_pipeline_with_only_setting_pipeline_level(self, client: MLClient) -
20662130
}
20672131
assert expected_job == actual_job
20682132

2069-
@pytest.mark.skip(
2070-
"https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659"
2071-
)
2133+
@pytest.mark.skip("https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659")
20722134
def test_dsl_pipeline_with_only_setting_binding_node(self, client: MLClient) -> None:
20732135
# TODO: check out run priority when backend is ready
2074-
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import pipeline_with_only_setting_binding_node
2136+
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import (
2137+
pipeline_with_only_setting_binding_node,
2138+
)
20752139

20762140
pipeline = pipeline_with_only_setting_binding_node()
20772141
pipeline_job = client.jobs.create_or_update(pipeline)
@@ -2130,9 +2194,7 @@ def test_dsl_pipeline_with_only_setting_binding_node(self, client: MLClient) ->
21302194
}
21312195
assert expected_job == actual_job
21322196

2133-
@pytest.mark.skip(
2134-
"https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659"
2135-
)
2197+
@pytest.mark.skip("https://dev.azure.com/msdata/Vienna/_workitems/edit/2009659")
21362198
def test_dsl_pipeline_with_setting_binding_node_and_pipeline_level(self, client: MLClient) -> None:
21372199
from test_configs.dsl_pipeline.pipeline_with_set_binding_output_input.pipeline import (
21382200
pipeline_with_setting_binding_node_and_pipeline_level,

0 commit comments

Comments
 (0)