Skip to content

Commit f317ab3

Browse files
authored
[NOT READY FOR MERGE] feat: support internal spark component (Azure#29271)
* feat: support internal spark component * feat: support conda file * refactor: only load internal Spark component as internal * refactor: inherit serialization logic from 3p spark component * feat: register internal spark component as public spark component (environment will be lost for now
1 parent 094cd29 commit f317ab3

39 files changed

+4366
-287
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/command.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
# ---------------------------------------------------------
44
from marshmallow import fields
55

6-
from azure.ai.ml._internal._schema.node import InternalBaseNodeSchema, NodeType
7-
from azure.ai.ml._schema import AnonymousEnvironmentSchema, ArmVersionedStr, NestedField, RegistryStr, UnionField
8-
from azure.ai.ml._schema.core.fields import DumpableEnumField
9-
from azure.ai.ml._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
10-
from azure.ai.ml._schema.job.job_limits import CommandJobLimitsSchema
11-
from azure.ai.ml.constants._common import AzureMLResourceType
6+
from ..._schema import AnonymousEnvironmentSchema, ArmVersionedStr, NestedField, RegistryStr, UnionField
7+
from ..._schema.core.fields import DumpableEnumField
8+
from ..._schema.job import ParameterizedCommandSchema, ParameterizedParallelSchema
9+
from ..._schema.job.job_limits import CommandJobLimitsSchema
10+
from ...constants._common import AzureMLResourceType
11+
from .._schema.node import InternalBaseNodeSchema, NodeType
1212

1313

1414
class CommandSchema(InternalBaseNodeSchema, ParameterizedCommandSchema):

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import pydash
77
from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
88

9-
from azure.ai.ml._schema import NestedField, StringTransformedEnum, UnionField
10-
from azure.ai.ml._schema.component.component import ComponentSchema
11-
from azure.ai.ml._schema.core.fields import ArmVersionedStr, CodeField
12-
from azure.ai.ml.constants._common import LABELLED_RESOURCE_NAME, AzureMLResourceType, SOURCE_PATH_CONTEXT_KEY
13-
14-
from .._utils import yaml_safe_load_with_base_resolver
9+
from ..._schema import AnonymousEnvironmentSchema, NestedField, StringTransformedEnum, UnionField
10+
from ..._schema.component.component import ComponentSchema
11+
from ..._schema.core.fields import ArmVersionedStr, CodeField, RegistryStr
12+
from ..._schema.job.parameterized_spark import SparkConfSchema, SparkEntryClassSchema, SparkEntryFileSchema
1513
from ..._utils._arm_id_utils import parse_name_label
1614
from ..._utils.utils import get_valid_dot_keys_with_wildcard
15+
from ...constants._common import LABELLED_RESOURCE_NAME, SOURCE_PATH_CONTEXT_KEY, AzureMLResourceType
16+
from ...constants._component import NodeType as PublicNodeType
17+
from .._utils import yaml_safe_load_with_base_resolver
1718
from .environment import InternalEnvironmentSchema
1819
from .input_output import (
1920
InternalEnumParameterSchema,
@@ -37,6 +38,9 @@ class NodeType:
3738
HEMERA = "HemeraComponent"
3839
AE365EXEPOOL = "AE365ExePoolComponent"
3940
IPP = "IntellectualPropertyProtectedComponent"
41+
# internal spake component got a type value conflict with spark component
42+
# this enum is used to identify its create_function in factories
43+
SPARK = "DummySpark"
4044

4145
@classmethod
4246
def all_values(cls):
@@ -161,3 +165,49 @@ def add_back_type_label(self, data, original, **kwargs): # pylint:disable=unuse
161165
if type_label:
162166
data["type"] = LABELLED_RESOURCE_NAME.format(data["type"], type_label)
163167
return data
168+
169+
170+
class InternalSparkComponentSchema(InternalComponentSchema):
171+
# type field is required for registration
172+
type = StringTransformedEnum(
173+
allowed_values=PublicNodeType.SPARK,
174+
casing_transform=lambda x: parse_name_label(x)[0].lower(),
175+
pass_original=True,
176+
)
177+
178+
environment = UnionField(
179+
[
180+
# unlike other internal component, internal spark component do not use internal environment schema
181+
NestedField(AnonymousEnvironmentSchema),
182+
RegistryStr(azureml_type=AzureMLResourceType.ENVIRONMENT),
183+
ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
184+
NestedField(InternalEnvironmentSchema),
185+
],
186+
allow_none=True,
187+
)
188+
189+
jars = UnionField(
190+
[
191+
fields.List(fields.Str()),
192+
fields.Str(),
193+
],
194+
)
195+
py_files = UnionField(
196+
[
197+
fields.List(fields.Str()),
198+
fields.Str(),
199+
],
200+
data_key="pyFiles",
201+
attribute="py_files",
202+
)
203+
204+
entry = UnionField(
205+
[NestedField(SparkEntryFileSchema), NestedField(SparkEntryClassSchema)],
206+
required=True,
207+
metadata={"description": "Entry."},
208+
)
209+
210+
files = fields.List(fields.Str(required=True))
211+
archives = fields.List(fields.Str(required=True))
212+
conf = NestedField(SparkConfSchema, unknown=INCLUDE)
213+
args = fields.Str(metadata={"description": "Command Line arguments."})

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from marshmallow import fields
66

7-
from azure.ai.ml._schema import PathAwareSchema
8-
from azure.ai.ml._schema.core.fields import DumpableEnumField, VersionField
7+
from ..._schema import PathAwareSchema
8+
from ..._schema.core.fields import DumpableEnumField, VersionField
99

1010

1111
class InternalEnvironmentSchema(PathAwareSchema):

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from marshmallow import fields, post_dump, post_load
66

7-
from azure.ai.ml._schema import PatchedSchemaMeta, StringTransformedEnum, UnionField
8-
from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema
9-
from azure.ai.ml._schema.core.fields import DumpableEnumField, PrimitiveValueField
7+
from ..._schema import PatchedSchemaMeta, StringTransformedEnum, UnionField
8+
from ..._schema.component.input_output import InputPortSchema, ParameterSchema
9+
from ..._schema.core.fields import DumpableEnumField, PrimitiveValueField
1010

1111
SUPPORTED_INTERNAL_PARAM_TYPES = [
1212
"integer",

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/node.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
from marshmallow import INCLUDE, fields, post_load, pre_dump
66

7-
from azure.ai.ml._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
8-
from azure.ai.ml._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
9-
from azure.ai.ml.constants._common import AzureMLResourceType
10-
7+
from ..._schema import ArmVersionedStr, NestedField, RegistryStr, UnionField
118
from ..._schema.core.fields import DumpableEnumField
9+
from ..._schema.pipeline.component_job import BaseNodeSchema, _resolve_inputs_outputs
10+
from ...constants._common import AzureMLResourceType
1211
from .component import InternalComponentSchema, NodeType
1312

1413

@@ -33,13 +32,13 @@ class Meta:
3332

3433
@post_load
3534
def make(self, data, **kwargs): # pylint: disable=unused-argument, no-self-use
36-
from azure.ai.ml.entities._builders import parse_inputs_outputs
35+
from ...entities._builders import parse_inputs_outputs
3736

3837
# parse inputs/outputs
3938
data = parse_inputs_outputs(data)
4039

4140
# dict to node object
42-
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
41+
from ...entities._job.pipeline._load_component import pipeline_node_factory
4342

4443
return pipeline_node_factory.load_from_dict(data=data)
4544

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_setup.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
# ---------------------------------------------------------
44

55
# pylint: disable=protected-access
6-
76
from marshmallow import INCLUDE
87

9-
from azure.ai.ml._internal._schema.command import CommandSchema, DistributedSchema, ParallelSchema
10-
from azure.ai.ml._internal._schema.component import NodeType
11-
from azure.ai.ml._internal._schema.node import HDInsightSchema, InternalBaseNodeSchema, ScopeSchema
12-
from azure.ai.ml._internal.entities import (
8+
from .._schema import NestedField
9+
from ..entities._component.component_factory import component_factory
10+
from ..entities._job.pipeline._load_component import pipeline_node_factory
11+
from ._schema.command import CommandSchema, DistributedSchema, ParallelSchema
12+
from ._schema.component import NodeType
13+
from ._schema.node import HDInsightSchema, InternalBaseNodeSchema, ScopeSchema
14+
from .entities import (
1315
Command,
1416
DataTransfer,
1517
Distributed,
@@ -21,9 +23,7 @@
2123
Scope,
2224
Starlite,
2325
)
24-
from azure.ai.ml._schema import NestedField
25-
from azure.ai.ml.entities._component.component_factory import component_factory
26-
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
26+
from .entities.spark import InternalSparkComponent
2727

2828
_registered = False
2929

@@ -41,6 +41,11 @@ def _enable_internal_components():
4141
create_instance_func=lambda: InternalComponent.__new__(InternalComponent),
4242
create_schema_func=create_schema_func,
4343
)
44+
component_factory.register_type(
45+
_type=NodeType.SPARK,
46+
create_instance_func=lambda: InternalSparkComponent.__new__(InternalSparkComponent),
47+
create_schema_func=InternalSparkComponent._create_schema_for_validation,
48+
)
4449

4550

4651
def _register_node(_type, node_cls, schema_cls):

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_additional_includes.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
import yaml
1313

14-
from azure.ai.ml._utils._asset_utils import IgnoreFile, traverse_directory
15-
from azure.ai.ml.entities._util import _general_copy
16-
from azure.ai.ml.entities._validation import MutableValidationResult, _ValidationResultBuilder
17-
14+
from ..._utils._asset_utils import IgnoreFile, traverse_directory
15+
from ...entities._util import _general_copy
16+
from ...entities._validation import MutableValidationResult, _ValidationResultBuilder
1817
from ._artifact_cache import ArtifactCache
1918
from .code import InternalComponentIgnoreFile
2019

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_input_outputs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
# ---------------------------------------------------------
44
from typing import Dict, Optional, Union
55

6-
from azure.ai.ml import Input, Output
7-
from azure.ai.ml._internal._schema.input_output import SUPPORTED_INTERNAL_PARAM_TYPES
8-
from azure.ai.ml._utils.utils import get_all_enum_values_iter
9-
from azure.ai.ml.constants import AssetTypes
10-
from azure.ai.ml.constants._common import InputTypes
11-
from azure.ai.ml.constants._component import ComponentParameterTypes, IOConstants
6+
from ... import Input, Output
7+
from ..._utils.utils import get_all_enum_values_iter
8+
from ...constants import AssetTypes
9+
from ...constants._common import InputTypes
10+
from ...constants._component import ComponentParameterTypes, IOConstants
11+
from .._schema.input_output import SUPPORTED_INTERNAL_PARAM_TYPES
1212

1313
_INPUT_TYPE_ENUM = "enum"
1414
_INPUT_TYPE_ENUM_CAP = "Enum"

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/command.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66

77
from marshmallow import INCLUDE, Schema
88

9-
from azure.ai.ml import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
10-
from azure.ai.ml._internal._schema.component import NodeType
11-
from azure.ai.ml._internal.entities.component import InternalComponent
12-
from azure.ai.ml._internal.entities.node import InternalBaseNode
13-
from azure.ai.ml._restclient.v2023_02_01_preview.models import CommandJobLimits as RestCommandJobLimits
14-
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
15-
from azure.ai.ml._schema import PathAwareSchema
16-
from azure.ai.ml._schema.core.fields import DistributionField
17-
from azure.ai.ml.entities import CommandJobLimits, JobResourceConfiguration
18-
from azure.ai.ml.entities._util import get_rest_dict_for_node_attrs
9+
from ... import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
10+
from ..._restclient.v2023_02_01_preview.models import CommandJobLimits as RestCommandJobLimits
11+
from ..._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
12+
from ..._schema import PathAwareSchema
13+
from ..._schema.core.fields import DistributionField
14+
from ...entities import CommandJobLimits, JobResourceConfiguration
15+
from ...entities._util import get_rest_dict_for_node_attrs
16+
from .._schema.component import NodeType
17+
from ..entities.component import InternalComponent
18+
from ..entities.node import InternalBaseNode
1919

2020

2121
class Command(InternalBaseNode):

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,19 @@
1111

1212
from marshmallow import Schema
1313

14-
from azure.ai.ml._restclient.v2022_05_01.models import ComponentVersionData, ComponentVersionDetails
15-
from azure.ai.ml._schema import PathAwareSchema
16-
from azure.ai.ml.entities import Component
17-
from azure.ai.ml.entities._system_data import SystemData
18-
from azure.ai.ml.entities._util import convert_ordered_dict_to_dict
19-
from azure.ai.ml.entities._validation import MutableValidationResult
20-
2114
from ... import Input, Output
15+
from ..._restclient.v2022_05_01.models import ComponentVersionData, ComponentVersionDetails
16+
from ..._schema import PathAwareSchema
2217
from ..._utils._arm_id_utils import parse_name_label
2318
from ..._utils._asset_utils import IgnoreFile
19+
from ...entities import Component
2420
from ...entities._assets import Code
2521
from ...entities._job.distribution import DistributionConfiguration
22+
from ...entities._system_data import SystemData
23+
from ...entities._util import convert_ordered_dict_to_dict
24+
from ...entities._validation import MutableValidationResult
2625
from .._schema.component import InternalComponentSchema
27-
from ._additional_includes import _AdditionalIncludes, ADDITIONAL_INCLUDES_SUFFIX
26+
from ._additional_includes import ADDITIONAL_INCLUDES_SUFFIX, _AdditionalIncludes
2827
from ._input_outputs import InternalInput, InternalOutput
2928
from ._merkle_tree import create_merkletree
3029
from .code import InternalCode, InternalComponentIgnoreFile

0 commit comments

Comments
 (0)