Skip to content

Commit eed613f

Browse files
committed
Properly expose entrypoint for Processor subclasses
1 parent 7b865f5 commit eed613f

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

src/sagemaker/processing.py

Lines changed: 26 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -509,6 +509,7 @@ def __init__(
509509
command: List[str] = None,
510510
instance_count: Union[int, PipelineVariable] = None,
511511
instance_type: Union[str, PipelineVariable] = None,
512+
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
512513
volume_size_in_gb: Union[int, PipelineVariable] = 30,
513514
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
514515
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
@@ -537,6 +538,9 @@ def __init__(
537538
a processing job with.
538539
instance_type (str or PipelineVariable): The type of EC2 instance to use for
539540
processing, for example, 'ml.c4.xlarge'.
541+
entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the
542+
processing job (default: None). This is in the form of a list of strings
543+
that make a command.
540544
volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume
541545
to use for storing data during processing (default: 30).
542546
volume_kms_key (str or PipelineVariable): A KMS key for the processing
@@ -572,6 +576,7 @@ def __init__(
572576
image_uri=image_uri,
573577
instance_count=instance_count,
574578
instance_type=instance_type,
579+
entrypoint=entrypoint,
575580
volume_size_in_gb=volume_size_in_gb,
576581
volume_kms_key=volume_kms_key,
577582
output_kms_key=output_kms_key,
@@ -845,14 +850,16 @@ def _set_entrypoint(self, command, user_script_name):
845850
Args:
846851
user_script_name (str): A filename with an extension.
847852
"""
848-
user_script_location = str(
849-
pathlib.PurePosixPath(
850-
self._CODE_CONTAINER_BASE_PATH,
851-
self._CODE_CONTAINER_INPUT_NAME,
852-
user_script_name,
853+
# Only set entrypoint if user hasn't provided one
854+
if self.entrypoint is None:
855+
user_script_location = str(
856+
pathlib.PurePosixPath(
857+
self._CODE_CONTAINER_BASE_PATH,
858+
self._CODE_CONTAINER_INPUT_NAME,
859+
user_script_name,
860+
)
853861
)
854-
)
855-
self.entrypoint = command + [user_script_location]
862+
self.entrypoint = command + [user_script_location]
856863

857864

858865
class ProcessingJob(_Job):
@@ -1434,6 +1441,7 @@ def __init__(
14341441
py_version: str = "py3",
14351442
image_uri: Optional[Union[str, PipelineVariable]] = None,
14361443
command: Optional[List[str]] = None,
1444+
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
14371445
volume_size_in_gb: Union[int, PipelineVariable] = 30,
14381446
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
14391447
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
@@ -1471,6 +1479,9 @@ def __init__(
14711479
command ([str]): The command to run, along with any command-line flags
14721480
to *precede* the ```code script```. Example: ["python3", "-v"]. If not
14731481
provided, ["python"] will be chosen (default: None).
1482+
entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the
1483+
processing job (default: None). This is in the form of a list of strings
1484+
that make a command.
14741485
volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume
14751486
to use for storing data during processing (default: 30).
14761487
volume_kms_key (str or PipelineVariable): A KMS key for the processing volume
@@ -1523,6 +1534,7 @@ def __init__(
15231534
command=command,
15241535
instance_count=instance_count,
15251536
instance_type=instance_type,
1537+
entrypoint=entrypoint,
15261538
volume_size_in_gb=volume_size_in_gb,
15271539
volume_kms_key=volume_kms_key,
15281540
output_kms_key=output_kms_key,
@@ -2001,13 +2013,14 @@ def _set_entrypoint(self, command, user_script_name):
20012013
command ([str]): Ignored in favor of self.framework_entrypoint_command
20022014
user_script_name (str): A filename with an extension.
20032015
"""
2004-
2005-
user_script_location = str(
2006-
pathlib.PurePosixPath(
2007-
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
2016+
# Only set entrypoint if user hasn't provided one
2017+
if self.entrypoint is None:
2018+
user_script_location = str(
2019+
pathlib.PurePosixPath(
2020+
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
2021+
)
20082022
)
2009-
)
2010-
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
2023+
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
20112024

20122025
def _create_and_upload_runproc(
20132026
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None

src/sagemaker/pytorch/processing.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ def __init__(
4141
py_version: str = "py3", # New kwarg
4242
image_uri: Optional[Union[str, PipelineVariable]] = None,
4343
command: Optional[List[str]] = None,
44+
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
4445
volume_size_in_gb: Union[int, PipelineVariable] = 30,
4546
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
4647
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
@@ -74,6 +75,7 @@ def __init__(
7475
py_version,
7576
image_uri,
7677
command,
78+
entrypoint,
7779
volume_size_in_gb,
7880
volume_kms_key,
7981
output_kms_key,

0 commit comments

Comments
 (0)