|
31 | 31 | from io import BytesIO |
32 | 32 | from urllib.parse import urlparse |
33 | 33 |
|
| 34 | +from typing import Union, List, Dict, Optional |
| 35 | + |
34 | 36 | from sagemaker import image_uris |
35 | 37 | from sagemaker.local.image import _ecr_login_if_needed, _pull_image |
36 | 38 | from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor |
37 | 39 | from sagemaker.s3 import S3Uploader |
38 | 40 | from sagemaker.session import Session |
| 41 | +from sagemaker.network import NetworkConfig |
39 | 42 | from sagemaker.spark import defaults |
40 | 43 |
|
| 44 | +from sagemaker.workflow import is_pipeline_variable |
| 45 | +from sagemaker.workflow.entities import PipelineVariable |
| 46 | +from sagemaker.workflow.functions import Join |
| 47 | + |
41 | 48 | logger = logging.getLogger(__name__) |
42 | 49 |
|
43 | 50 |
|
@@ -249,6 +256,12 @@ def run( |
249 | 256 | """ |
250 | 257 | self._current_job_name = self._generate_current_job_name(job_name=job_name) |
251 | 258 |
|
| 259 | + if is_pipeline_variable(submit_app): |
| 260 | + raise ValueError( |
| 261 | + "submit_app argument has to be a valid S3 URI or local file path " |
| 262 | + + "rather than a pipeline variable" |
| 263 | + ) |
| 264 | + |
252 | 265 | return super().run( |
253 | 266 | submit_app, |
254 | 267 | inputs, |
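For context, here is a minimal sketch (not part of the diff) of what the new guard rejects. It assumes a `ParameterString` from `sagemaker.workflow.parameters` is passed as `submit_app`; the parameter name and bucket are hypothetical.

```python
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString

# Hypothetical pipeline parameter standing in for the application script.
submit_app = ParameterString(name="SubmitApp", default_value="s3://my-bucket/app.py")

# Mirrors the guard added above: pipeline variables are rejected because the
# processor has to read and stage submit_app when the step is defined, while a
# pipeline variable only resolves to a concrete value at execution time.
if is_pipeline_variable(submit_app):
    print("rejected: submit_app must be a concrete S3 URI or local file path")
```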
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): |
437 | 450 |
|
438 | 451 | use_input_channel = False |
439 | 452 | spark_opt_s3_uris = [] |
| 453 | + spark_opt_s3_uris_has_pipeline_var = False |
440 | 454 |
|
441 | 455 | with tempfile.TemporaryDirectory() as tmpdir: |
442 | 456 | for dep_path in submit_deps: |
| 457 | + if is_pipeline_variable(dep_path): |
| 458 | + spark_opt_s3_uris.append(dep_path) |
| 459 | + spark_opt_s3_uris_has_pipeline_var = True |
| 460 | + continue |
443 | 461 | dep_url = urlparse(dep_path) |
444 | 462 | # S3 URIs are included as-is in the spark-submit argument |
445 | 463 | if dep_url.scheme in ["s3", "s3a"]: |
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): |
482 | 500 | destination=f"{self._conf_container_base_path}{input_channel_name}", |
483 | 501 | input_name=input_channel_name, |
484 | 502 | ) |
485 | | - spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination]) |
| 503 | + spark_opt = ( |
| 504 | + Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) |
| 505 | + if spark_opt_s3_uris_has_pipeline_var |
| 506 | + else ",".join(spark_opt_s3_uris + [input_channel.destination]) |
| 507 | + ) |
486 | 508 | # If no local files were uploaded, form the spark-submit option from a list of S3 URIs |
487 | 509 | else: |
488 | 510 | input_channel = None |
489 | | - spark_opt = ",".join(spark_opt_s3_uris) |
| 511 | + spark_opt = ( |
| 512 | + Join(on=",", values=spark_opt_s3_uris) |
| 513 | + if spark_opt_s3_uris_has_pipeline_var |
| 514 | + else ",".join(spark_opt_s3_uris) |
| 515 | + ) |
490 | 516 |
|
491 | 517 | return input_channel, spark_opt |
492 | 518 |
|
@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path): |
592 | 618 | Args: |
593 | 619 | spark_output_s3_path (str): The URI of the Spark output S3 Path. |
594 | 620 | """ |
| 621 | + if is_pipeline_variable(spark_output_s3_path): |
| 622 | + return |
| 623 | + |
595 | 624 | if urlparse(spark_output_s3_path).scheme != "s3": |
596 | 625 | raise ValueError( |
597 | 626 | f"Invalid s3 path: {spark_output_s3_path}. Please enter something like " |
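A brief sketch of why the early return is safe to add: a pipeline variable such as a `ParameterString` has no concrete value at step-definition time, so its scheme cannot be checked with `urlparse` until the pipeline actually runs. The parameter name is hypothetical.

```python
from urllib.parse import urlparse

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString

spark_output_s3_path = ParameterString(name="SparkEventLogsS3Uri")

# Mirrors the new behavior: only concrete strings are validated eagerly;
# pipeline variables are left for SageMaker to resolve at execution time.
if not is_pipeline_variable(spark_output_s3_path):
    assert urlparse(spark_output_s3_path).scheme == "s3"
```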
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase): |
650 | 679 |
|
651 | 680 | def __init__( |
652 | 681 | self, |
653 | | - role, |
654 | | - instance_type, |
655 | | - instance_count, |
656 | | - framework_version=None, |
657 | | - py_version=None, |
658 | | - container_version=None, |
659 | | - image_uri=None, |
660 | | - volume_size_in_gb=30, |
661 | | - volume_kms_key=None, |
662 | | - output_kms_key=None, |
663 | | - max_runtime_in_seconds=None, |
664 | | - base_job_name=None, |
665 | | - sagemaker_session=None, |
666 | | - env=None, |
667 | | - tags=None, |
668 | | - network_config=None, |
| 682 | + role: str, |
| 683 | + instance_type: Union[str, PipelineVariable], |
| 684 | + instance_count: Union[int, PipelineVariable], |
| 685 | + framework_version: Optional[str] = None, |
| 686 | + py_version: Optional[str] = None, |
| 687 | + container_version: Optional[str] = None, |
| 688 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 689 | + volume_size_in_gb: Union[int, PipelineVariable] = 30, |
| 690 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 691 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 692 | + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, |
| 693 | + base_job_name: Optional[str] = None, |
| 694 | + sagemaker_session: Optional[Session] = None, |
| 695 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 696 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 697 | + network_config: Optional[NetworkConfig] = None, |
669 | 698 | ): |
670 | 699 | """Initialize a ``PySparkProcessor`` instance. |
671 | 700 |
|
@@ -795,20 +824,20 @@ def get_run_args( |
795 | 824 |
|
796 | 825 | def run( |
797 | 826 | self, |
798 | | - submit_app, |
799 | | - submit_py_files=None, |
800 | | - submit_jars=None, |
801 | | - submit_files=None, |
802 | | - inputs=None, |
803 | | - outputs=None, |
804 | | - arguments=None, |
805 | | - wait=True, |
806 | | - logs=True, |
807 | | - job_name=None, |
808 | | - experiment_config=None, |
809 | | - configuration=None, |
810 | | - spark_event_logs_s3_uri=None, |
811 | | - kms_key=None, |
| 827 | + submit_app: str, |
| 828 | + submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None, |
| 829 | + submit_jars: Optional[List[Union[str, PipelineVariable]]] = None, |
| 830 | + submit_files: Optional[List[Union[str, PipelineVariable]]] = None, |
| 831 | + inputs: Optional[List[ProcessingInput]] = None, |
| 832 | + outputs: Optional[List[ProcessingOutput]] = None, |
| 833 | + arguments: Optional[List[Union[str, PipelineVariable]]] = None, |
| 834 | + wait: bool = True, |
| 835 | + logs: bool = True, |
| 836 | + job_name: Optional[str] = None, |
| 837 | + experiment_config: Optional[Dict[str, str]] = None, |
| 838 | + configuration: Optional[Union[List[Dict], Dict]] = None, |
| 839 | + spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 840 | + kms_key: Optional[str] = None, |
812 | 841 | ): |
813 | 842 | """Runs a processing job. |
814 | 843 |
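A sketch, under stated assumptions, of how the annotated `PySparkProcessor` might be wired up in a pipeline definition now that staged dependencies and the event-log URI can be pipeline variables. The role ARN, bucket, parameter names, and framework version are placeholders, AWS credentials and a default region are assumed to be configured, and `submit_app` itself still has to be a concrete path or S3 URI.

```python
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterInteger, ParameterString

# Hypothetical pipeline parameters.
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
extra_py_files = ParameterString(
    name="ExtraPyFiles", default_value="s3://my-bucket/deps/utils.py"
)
event_logs_uri = ParameterString(
    name="SparkEventLogsUri", default_value="s3://my-bucket/spark-events/"
)

processor = PySparkProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    instance_type="ml.m5.xlarge",
    instance_count=instance_count,
    framework_version="3.1",
)

# get_run_args() exercises the staging and validation paths changed in this
# diff; its outputs would typically feed a pipeline ProcessingStep.
run_args = processor.get_run_args(
    submit_app="s3://my-bucket/code/app.py",  # must stay a concrete URI/path
    submit_py_files=[extra_py_files],
    spark_event_logs_s3_uri=event_logs_uri,
)
```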
|
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase): |
907 | 936 |
|
908 | 937 | def __init__( |
909 | 938 | self, |
910 | | - role, |
911 | | - instance_type, |
912 | | - instance_count, |
913 | | - framework_version=None, |
914 | | - py_version=None, |
915 | | - container_version=None, |
916 | | - image_uri=None, |
917 | | - volume_size_in_gb=30, |
918 | | - volume_kms_key=None, |
919 | | - output_kms_key=None, |
920 | | - max_runtime_in_seconds=None, |
921 | | - base_job_name=None, |
922 | | - sagemaker_session=None, |
923 | | - env=None, |
924 | | - tags=None, |
925 | | - network_config=None, |
| 939 | + role: str, |
| 940 | + instance_type: Union[str, PipelineVariable], |
| 941 | + instance_count: Union[int, PipelineVariable], |
| 942 | + framework_version: Optional[str] = None, |
| 943 | + py_version: Optional[str] = None, |
| 944 | + container_version: Optional[str] = None, |
| 945 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 946 | + volume_size_in_gb: Union[int, PipelineVariable] = 30, |
| 947 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 948 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 949 | + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, |
| 950 | + base_job_name: Optional[str] = None, |
| 951 | + sagemaker_session: Optional[Session] = None, |
| 952 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 953 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 954 | + network_config: Optional[NetworkConfig] = None, |
926 | 955 | ): |
927 | 956 | """Initialize a ``SparkJarProcessor`` instance. |
928 | 957 |
|
@@ -1052,20 +1081,20 @@ def get_run_args( |
1052 | 1081 |
|
1053 | 1082 | def run( |
1054 | 1083 | self, |
1055 | | - submit_app, |
1056 | | - submit_class=None, |
1057 | | - submit_jars=None, |
1058 | | - submit_files=None, |
1059 | | - inputs=None, |
1060 | | - outputs=None, |
1061 | | - arguments=None, |
1062 | | - wait=True, |
1063 | | - logs=True, |
1064 | | - job_name=None, |
1065 | | - experiment_config=None, |
1066 | | - configuration=None, |
1067 | | - spark_event_logs_s3_uri=None, |
1068 | | - kms_key=None, |
| 1084 | + submit_app: str, |
| 1085 | + submit_class: Union[str, PipelineVariable], |
| 1086 | + submit_jars: Optional[List[Union[str, PipelineVariable]]] = None, |
| 1087 | + submit_files: Optional[List[Union[str, PipelineVariable]]] = None, |
| 1088 | + inputs: Optional[List[ProcessingInput]] = None, |
| 1089 | + outputs: Optional[List[ProcessingOutput]] = None, |
| 1090 | + arguments: Optional[List[Union[str, PipelineVariable]]] = None, |
| 1091 | + wait: bool = True, |
| 1092 | + logs: bool = True, |
| 1093 | + job_name: Optional[str] = None, |
| 1094 | + experiment_config: Optional[Dict[str, str]] = None, |
| 1095 | + configuration: Optional[Union[List[Dict], Dict]] = None, |
| 1096 | + spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 1097 | + kms_key: Optional[str] = None, |
1069 | 1098 | ): |
1070 | 1099 | """Runs a processing job. |
1071 | 1100 |
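And a parallel sketch for `SparkJarProcessor`, where the updated `run()` signature makes `submit_class` a required argument; all names below are hypothetical and the same AWS-side assumptions apply.

```python
from sagemaker.spark.processing import SparkJarProcessor
from sagemaker.workflow.parameters import ParameterString

extra_jar = ParameterString(name="ExtraJar", default_value="s3://my-bucket/deps/extra.jar")

jar_processor = SparkJarProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    instance_type="ml.m5.xlarge",
    instance_count=2,
    framework_version="3.1",
)

run_args = jar_processor.get_run_args(
    submit_app="s3://my-bucket/code/spark-app.jar",  # concrete URI required
    submit_class="com.example.SparkApp",             # hypothetical main class
    submit_jars=[extra_jar],                          # pipeline variable now accepted
)
```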
|
|