@@ -171,6 +171,39 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        code,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+    ):
+        """Returns a RunArgs object.
+
+        For processors (:class:`~sagemaker.spark.processing.PySparkProcessor`,
+        :class:`~sagemaker.spark.processing.SparkJarProcessor`) that have special
+        run() arguments, this object contains the normalized arguments for passing to
+        :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            code (str): This can be an S3 URI or a local path to a file with the framework
+                script to run.
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+        """
+        return super().get_run_args(
+            code=code,
+            inputs=inputs,
+            outputs=outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
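
To make the intent concrete, here is a minimal usage sketch (not part of the commit; the role ARN, paths, and names are placeholders): the RunArgs fields map one-to-one onto ProcessingStep parameters.

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.steps import ProcessingStep

# Placeholder processor configuration; any valid role/instance settings work.
spark_processor = PySparkProcessor(
    base_job_name="spark-preprocess",
    framework_version="3.0",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

# Normalize the run() arguments instead of starting a job immediately.
run_args = spark_processor.get_run_args(submit_app="s3://my-bucket/code/preprocess.py")

# The normalized pieces wire directly into a pipeline step.
step = ProcessingStep(
    name="SparkPreprocess",
    processor=spark_processor,
    code=run_args.code,            # submit_app, surfaced as `code`
    inputs=run_args.inputs,
    outputs=run_args.outputs,
    job_arguments=run_args.arguments,
)
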
@@ -685,6 +718,73 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        submit_app,
+        submit_py_files=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+    ):
+        """Returns a RunArgs object.
+
+        This object contains the normalized inputs, outputs
+        and arguments needed when using a ``PySparkProcessor``
+        in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Python file to submit to Spark
+                as the primary application. This is translated to the `code`
+                property on the returned `RunArgs` object.
+            submit_py_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --py-files` option
+            submit_jars (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --jars` option
+            submit_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --files` option
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where spark application events will
+                be published to.
+        """
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs=inputs,
+            outputs=outputs,
+            submit_py_files=submit_py_files,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            code=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
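
A sketch of how the PySpark-specific options flow through this method (bucket names, paths, and the role ARN are placeholders): _extend_processing_args() turns submit_py_files, configuration, and spark_event_logs_s3_uri into additional processing inputs/outputs and container options, so the returned run_args already carries them without a run() call.

from sagemaker.spark.processing import PySparkProcessor

spark_processor = PySparkProcessor(
    base_job_name="spark-preprocess",
    framework_version="3.0",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = spark_processor.get_run_args(
    submit_app="s3://my-bucket/code/preprocess.py",        # placeholder paths
    submit_py_files=["s3://my-bucket/code/helpers.py"],
    configuration={
        "Classification": "spark-defaults",
        "Properties": {"spark.executor.memory": "2g"},
    },
    spark_event_logs_s3_uri="s3://my-bucket/spark-events/",
    arguments=["--input", "s3://my-bucket/raw/", "--output", "s3://my-bucket/clean/"],
)

# Dependencies, serialized configuration, and the event-log location are now
# reflected in the normalized fields rather than applied at run() time.
print(run_args.code)
print(run_args.inputs)
print(run_args.outputs)
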
@@ -738,14 +838,13 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]

         if not submit_app:
             raise ValueError("submit_app is required")

         extended_inputs, extended_outputs = self._extend_processing_args(
-            inputs,
-            outputs,
+            inputs=inputs,
+            outputs=outputs,
             submit_py_files=submit_py_files,
             submit_jars=submit_jars,
             submit_files=submit_files,
@@ -762,6 +861,7 @@ def run(
             logs=logs,
             job_name=self._current_job_name,
             experiment_config=experiment_config,
+            kms_key=kms_key,
         )

     def _extend_processing_args(self, inputs, outputs, **kwargs):
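
With kms_key now forwarded to the parent run() call, the key is applied when the user code is uploaded; a brief sketch (placeholder key ARN and script path, reusing spark_processor from the sketches above):

# kms_key is passed through to the underlying processing job and is used to
# encrypt the uploaded user code file. Key ARN and script path are placeholders.
spark_processor.run(
    submit_app="s3://my-bucket/code/preprocess.py",
    kms_key="arn:aws:kms:us-east-1:111122223333:key/EXAMPLE-1234-KEY",
    wait=False,
)
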
@@ -772,6 +872,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
             outputs: Processing outputs.
             kwargs: Additional keyword arguments passed to `super()`.
         """
+        self.command = [_SparkProcessorBase._default_command]
         extended_inputs = self._handle_script_dependencies(
             inputs, kwargs.get("submit_py_files"), FileType.PYTHON
         )
@@ -866,6 +967,73 @@ def __init__(
             network_config=network_config,
         )

+    def get_run_args(
+        self,
+        submit_app,
+        submit_class=None,
+        submit_jars=None,
+        submit_files=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        job_name=None,
+        configuration=None,
+        spark_event_logs_s3_uri=None,
+    ):
+        """Returns a RunArgs object.
+
+        This object contains the normalized inputs, outputs
+        and arguments needed when using a ``SparkJarProcessor``
+        in a :class:`~sagemaker.workflow.steps.ProcessingStep`.
+
+        Args:
+            submit_app (str): Path (local or S3) to Jar file to submit to Spark
+                as the primary application. This is translated to the `code`
+                property on the returned `RunArgs` object.
+            submit_class (str): Java class reference to submit to Spark as the primary
+                application
+            submit_jars (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --jars` option
+            submit_files (list[str]): List of paths (local or S3) to provide for
+                `spark-submit --files` option
+            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
+                the processing job. These must be provided as
+                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
+            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
+                the processing job. These can be specified as either path strings or
+                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
+            arguments (list[str]): A list of string arguments to be passed to a
+                processing job (default: None).
+            job_name (str): Processing job name. If not specified, the processor generates
+                a default job name, based on the base job name and current timestamp.
+            configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive.
+                List or dictionary of EMR-style classifications.
+                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
+            spark_event_logs_s3_uri (str): S3 path where spark application events will
+                be published to.
+        """
+        self._current_job_name = self._generate_current_job_name(job_name=job_name)
+
+        if not submit_app:
+            raise ValueError("submit_app is required")
+
+        extended_inputs, extended_outputs = self._extend_processing_args(
+            inputs=inputs,
+            outputs=outputs,
+            submit_class=submit_class,
+            submit_jars=submit_jars,
+            submit_files=submit_files,
+            configuration=configuration,
+            spark_event_logs_s3_uri=spark_event_logs_s3_uri,
+        )
+
+        return super().get_run_args(
+            code=submit_app,
+            inputs=extended_inputs,
+            outputs=extended_outputs,
+            arguments=arguments,
+        )
+
     def run(
         self,
         submit_app,
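
A companion sketch for the jar variant (paths, class name, and role are placeholders): submit_class becomes a --class option on the container command inside _extend_processing_args(), while submit_app is surfaced as run_args.code.

from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    base_job_name="spark-java-preprocess",
    framework_version="3.0",
    role="arn:aws:iam::111122223333:role/SageMakerRole",   # placeholder role
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

run_args = jar_processor.get_run_args(
    submit_app="s3://my-bucket/code/preprocessing.jar",    # placeholder jar
    submit_class="com.example.PreprocessingJob",           # placeholder class
    arguments=["--input", "s3://my-bucket/raw/"],
)

# run_args.code / .inputs / .outputs / .arguments wire into a ProcessingStep
# exactly as in the PySpark sketch earlier in this diff.
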
@@ -919,14 +1087,13 @@ def run(
                 user code file (default: None).
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
-        self.command = [_SparkProcessorBase._default_command]

         if not submit_app:
             raise ValueError("submit_app is required")

         extended_inputs, extended_outputs = self._extend_processing_args(
-            inputs,
-            outputs,
+            inputs=inputs,
+            outputs=outputs,
             submit_class=submit_class,
             submit_jars=submit_jars,
             submit_files=submit_files,
@@ -947,6 +1114,7 @@ def run(
         )

     def _extend_processing_args(self, inputs, outputs, **kwargs):
+        self.command = [_SparkProcessorBase._default_command]
         if kwargs.get("submit_class"):
             self.command.extend(["--class", kwargs.get("submit_class")])
         else: