
Commit a0b7966

Adding additional include support in spark component (Azure#38537)
Support additional include in spark component
1 parent: c812e66 · commit: a0b7966

5 files changed: +98 -3 lines

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 ## 1.23.0 (unreleased)
 
 ### Features Added
+- Add support for additional include in spark component.
 
 ### Bugs Fixed
 
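For context, a minimal usage sketch of the new feature (assuming the azure-ai-ml package with this change installed; the YAML path below is illustrative, and the included folder must exist next to the YAML):

from azure.ai.ml import load_component

# Load a spark component whose YAML declares additional_includes
# (illustrative path; see the test config added in this commit).
spark_component = load_component(
    "./components/hello_spark_component_with_additional_include.yml"
)

# The declared entries surface as a plain list on the loaded entity.
print(spark_component.additional_includes)  # e.g. ['common_src']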

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/spark_component.py

Lines changed: 11 additions & 1 deletion
@@ -7,7 +7,7 @@
 from copy import deepcopy
 
 import yaml
-from marshmallow import INCLUDE, fields, post_load
+from marshmallow import INCLUDE, fields, post_dump, post_load
 
 from azure.ai.ml._schema.assets.asset import AnonymousAssetSchema
 from azure.ai.ml._schema.component.component import ComponentSchema
@@ -20,6 +20,16 @@
 
 class SparkComponentSchema(ComponentSchema, ParameterizedSparkSchema):
     type = StringTransformedEnum(allowed_values=[NodeType.SPARK])
+    additional_includes = fields.List(fields.Str())
+
+    @post_dump
+    def remove_unnecessary_fields(self, component_schema_dict, **kwargs):
+        if (
+            component_schema_dict.get("additional_includes") is not None
+            and len(component_schema_dict["additional_includes"]) == 0
+        ):
+            component_schema_dict.pop("additional_includes")
+        return component_schema_dict
 
 
 class RestSparkComponentSchema(SparkComponentSchema):
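The post_dump hook above keeps an empty additional_includes list out of the serialized output. A standalone marshmallow sketch of the same pattern (ExampleSchema is illustrative, not part of the SDK):

from marshmallow import Schema, fields, post_dump


class ExampleSchema(Schema):
    name = fields.Str()
    additional_includes = fields.List(fields.Str())

    @post_dump
    def remove_empty_additional_includes(self, data, **kwargs):
        # Drop the key only when it is present and empty, mirroring the schema change above.
        if data.get("additional_includes") == []:
            data.pop("additional_includes")
        return data


print(ExampleSchema().dump({"name": "demo", "additional_includes": []}))
# -> {'name': 'demo'}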

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/spark_component.py

Lines changed: 6 additions & 2 deletions
@@ -17,12 +17,12 @@
 from .._job.spark_job_entry_mixin import SparkJobEntry, SparkJobEntryMixin
 from .._util import convert_ordered_dict_to_dict, validate_attribute_type
 from .._validation import MutableValidationResult
-from .code import ComponentCodeMixin
+from ._additional_includes import AdditionalIncludesMixin
 from .component import Component
 
 
 class SparkComponent(
-    Component, ParameterizedSpark, SparkJobEntryMixin, ComponentCodeMixin
+    Component, ParameterizedSpark, SparkJobEntryMixin, AdditionalIncludesMixin
 ):  # pylint: disable=too-many-instance-attributes
     """Spark component version, used to define a Spark Component or Job.
 
@@ -79,6 +79,8 @@ class SparkComponent(
     :paramtype outputs: Optional[dict[str, Union[str, ~azure.ai.ml.Output]]]
     :keyword args: The arguments for the job. Defaults to None.
     :paramtype args: Optional[str]
+    :keyword additional_includes: A list of shared additional files to be included in the component. Defaults to None.
+    :paramtype additional_includes: Optional[List[str]]
 
     .. admonition:: Example:
 
@@ -112,6 +114,7 @@ def __init__(
         inputs: Optional[Dict] = None,
         outputs: Optional[Dict] = None,
         args: Optional[str] = None,
+        additional_includes: Optional[List] = None,
         **kwargs: Any,
     ) -> None:
         # validate init params are valid type
@@ -134,6 +137,7 @@ def __init__(
         self.conf = conf
         self.environment = environment
         self.args = args
+        self.additional_includes = additional_includes or []
         # For pipeline spark job, we also allow user to set driver_cores, driver_memory and so on by setting conf.
         # If root level fields are not set by user, we promote conf setting to root level to facilitate subsequent
         # verification. This usually happens when we use to_component(SparkJob) or builder function spark() as a node
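A minimal construction sketch using the new keyword (values are illustrative; additional_includes entries are typically resolved relative to the component's base path):

from azure.ai.ml.entities import SparkComponent

component = SparkComponent(
    name="wordcount_spark_component",
    entry={"file": "wordcount.py"},
    additional_includes=["common_src"],  # shared folder packaged alongside the component code
)
print(component.additional_includes)  # ['common_src']
# When the keyword is omitted, the attribute defaults to an empty list.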

sdk/ml/azure-ai-ml/tests/component/unittests/test_spark_component_entity.py

Lines changed: 53 additions & 0 deletions
@@ -29,6 +29,18 @@ def test_component_load(self):
         }
         assert spark_component.args == "--file_input ${{inputs.file_input}}"
 
+    def test_component_load_with_additional_include(self):
+        # code is specified in yaml, value is respected
+        component_yaml = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
+        spark_component = load_component(
+            component_yaml,
+        )
+
+        assert (
+            isinstance(spark_component.additional_includes, list)
+            and spark_component.additional_includes[0] == "common_src"
+        )
+
     def test_spark_component_to_dict(self):
         # Test optional params exists in component dict
         yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
@@ -37,6 +49,14 @@ def test_spark_component_to_dict(self):
         spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
         assert spark_component._other_parameter.get("mock_option_param") == yaml_dict["mock_option_param"]
 
+    def test_spark_component_to_dict_additional_include(self):
+        # Test optional params exists in component dict
+        yaml_path = "./tests/test_configs/dsl_pipeline/spark_job_in_pipeline/add_greeting_column_component.yml"
+        yaml_dict = load_yaml(yaml_path)
+        yaml_dict["additional_includes"] = ["common_src"]
+        spark_component = SparkComponent._load(data=yaml_dict, yaml_path=yaml_path)
+        assert spark_component.additional_includes[0] == yaml_dict["additional_includes"][0]
+
     def test_spark_component_entity(self):
         component = SparkComponent(
             name="add_greeting_column_spark_component",
@@ -73,6 +93,39 @@ def test_spark_component_entity(self):
 
         assert component_dict == yaml_component_dict
 
+    def test_spark_component_entity_additional_include(self):
+        component = SparkComponent(
+            name="wordcount_spark_component",
+            display_name="Spark word count",
+            description="Spark word count",
+            version="3",
+            inputs={
+                "file_input": {"type": "uri_file", "mode": "direct"},
+            },
+            driver_cores=1,
+            driver_memory="2g",
+            executor_cores=2,
+            executor_memory="2g",
+            executor_instances=4,
+            entry={"file": "wordcount.py"},
+            args="--input1 ${{inputs.file_input}}",
+            base_path="./tests/test_configs/components",
+            additional_includes=["common_src"],
+        )
+        omit_fields = [
+            "properties.component_spec.$schema",
+            "properties.component_spec._source",
+            "properties.properties.client_component_hash",
+        ]
+        component_dict = component._to_rest_object().as_dict()
+        component_dict = pydash.omit(component_dict, *omit_fields)
+
+        yaml_path = "./tests/test_configs/components/hello_spark_component_with_additional_include.yml"
+        yaml_component = load_component(yaml_path)
+        yaml_component_dict = yaml_component._to_rest_object().as_dict()
+        yaml_component_dict = pydash.omit(yaml_component_dict, *omit_fields)
+        assert component_dict == yaml_component_dict
+
     def test_spark_component_version_as_a_function_with_inputs(self):
         expected_rest_component = {
             "type": "spark",
sdk/ml/azure-ai-ml/tests/test_configs/components/hello_spark_component_with_additional_include.yml

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+$schema: https://azuremlschemas.azureedge.net/latest/sparkComponent.schema.json
+name: wordcount_spark_component
+type: spark
+version: 3
+display_name: Spark word count
+description: Spark word count
+
+
+inputs:
+  file_input:
+    type: uri_file
+    mode: direct
+
+entry:
+  file: wordcount.py
+
+args: >-
+  --input1 ${{inputs.file_input}}
+
+conf:
+  spark.driver.cores: 1
+  spark.driver.memory: "2g"
+  spark.executor.cores: 2
+  spark.executor.memory: "2g"
+  spark.executor.instances: 4
+additional_includes:
+  - common_src
