2424from textwrap import dedent
2525from typing import Dict , List , Optional , Union
2626from copy import copy
27+ import re
2728
2829import attr
2930
@@ -1658,6 +1659,7 @@ def run( # type: ignore[override]
16581659 job_name : Optional [str ] = None ,
16591660 experiment_config : Optional [Dict [str , str ]] = None ,
16601661 kms_key : Optional [str ] = None ,
1662+ codeartifact_repo_arn : Optional [str ] = None ,
16611663 ):
16621664 """Runs a processing job.
16631665
@@ -1758,12 +1760,21 @@ def run( # type: ignore[override]
17581760 However, the value of `TrialComponentDisplayName` is honored for display in Studio.
17591761 kms_key (str): The ARN of the KMS key that is used to encrypt the
17601762 user code file (default: None).
1763+ codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1764+ logged into before installing dependencies (default: None).
17611765 Returns:
17621766 None or pipeline step arguments in case the Processor instance is built with
17631767 :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
17641768 """
17651769 s3_runproc_sh , inputs , job_name = self ._pack_and_upload_code (
1766- code , source_dir , dependencies , git_config , job_name , inputs , kms_key
1770+ code ,
1771+ source_dir ,
1772+ dependencies ,
1773+ git_config ,
1774+ job_name ,
1775+ inputs ,
1776+ kms_key ,
1777+ codeartifact_repo_arn ,
17671778 )
17681779
17691780 # Submit a processing job.
@@ -1780,7 +1791,15 @@ def run( # type: ignore[override]
17801791 )
17811792
17821793 def _pack_and_upload_code (
1783- self , code , source_dir , dependencies , git_config , job_name , inputs , kms_key = None
1794+ self ,
1795+ code ,
1796+ source_dir ,
1797+ dependencies ,
1798+ git_config ,
1799+ job_name ,
1800+ inputs ,
1801+ kms_key = None ,
1802+ codeartifact_repo_arn = None ,
17841803 ):
17851804 """Pack local code bundle and upload to Amazon S3."""
17861805 if code .startswith ("s3://" ):
@@ -1821,12 +1840,53 @@ def _pack_and_upload_code(
18211840 script = estimator .uploaded_code .script_name
18221841 evaluated_kms_key = kms_key if kms_key else self .output_kms_key
18231842 s3_runproc_sh = self ._create_and_upload_runproc (
1824- script , evaluated_kms_key , entrypoint_s3_uri
1843+ script , evaluated_kms_key , entrypoint_s3_uri , codeartifact_repo_arn
18251844 )
18261845
18271846 return s3_runproc_sh , inputs , job_name
18281847
1829- def _generate_framework_script (self , user_script : str ) -> str :
1848+ def _get_codeartifact_command (self , codeartifact_repo_arn : str ) -> str :
1849+ """Build an AWS CLI CodeArtifact command to configure pip.
1850+
1851+ The codeartifact_repo_arn property must follow the form
1852+ # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
1853+ https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
1854+ https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1855+
1856+ Args:
1857+ codeartifact_repo_arn: arn of the codeartifact repository
1858+ Returns:
1859+ codeartifact command string
1860+ """
1861+
1862+ arn_regex = (
1863+ "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
1864+ ":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
1865+ )
1866+ m = re .match (arn_regex , codeartifact_repo_arn )
1867+ if not m :
1868+ raise ValueError ("invalid CodeArtifact repository arn {}" .format (codeartifact_repo_arn ))
1869+ domain = m .group ("domain" )
1870+ owner = m .group ("account" )
1871+ repository = m .group ("repository" )
1872+ region = m .group ("region" )
1873+
1874+ logger .info (
1875+ "configuring pip to use codeartifact "
1876+ "(domain: %s, domain owner: %s, repository: %s, region: %s)" ,
1877+ domain ,
1878+ owner ,
1879+ repository ,
1880+ region ,
1881+ )
1882+
1883+ return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}" .format ( # noqa: E501 pylint: disable=line-too-long
1884+ domain , owner , repository , region
1885+ )
1886+
1887+ def _generate_framework_script (
1888+ self , user_script : str , codeartifact_repo_arn : str = None
1889+ ) -> str :
18301890 """Generate the framework entrypoint file (as text) for a processing job.
18311891
18321892 This script implements the "framework" functionality for setting up your code:
@@ -1837,7 +1897,16 @@ def _generate_framework_script(self, user_script: str) -> str:
18371897 Args:
18381898 user_script (str): Relative path to ```code``` in the source bundle
18391899 - e.g. 'process.py'.
1900+ codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1901+ logged into before installing dependencies (default: None).
18401902 """
1903+ if codeartifact_repo_arn :
1904+ codeartifact_login_command = self ._get_codeartifact_command (codeartifact_repo_arn )
1905+ else :
1906+ codeartifact_login_command = (
1907+ "echo 'CodeArtifact repository not specified. Skipping login.'"
1908+ )
1909+
18411910 return dedent (
18421911 """\
18431912 #!/bin/bash
@@ -1849,6 +1918,13 @@ def _generate_framework_script(self, user_script: str) -> str:
18491918 set -e
18501919
18511920 if [[ -f 'requirements.txt' ]]; then
1921+ # Optionally log into CodeArtifact
1922+ if ! hash aws 2>/dev/null; then
1923+ echo "AWS CLI is not installed. Skipping CodeArtifact login."
1924+ else
1925+ {codeartifact_login_command}
1926+ fi
1927+
18521928 # Some py3 containers has typing, which may breaks pip install
18531929 pip uninstall --yes typing
18541930
@@ -1858,6 +1934,7 @@ def _generate_framework_script(self, user_script: str) -> str:
18581934 {entry_point_command} {entry_point} "$@"
18591935 """
18601936 ).format (
1937+ codeartifact_login_command = codeartifact_login_command ,
18611938 entry_point_command = " " .join (self .command ),
18621939 entry_point = user_script ,
18631940 )
@@ -1933,7 +2010,9 @@ def _set_entrypoint(self, command, user_script_name):
19332010 )
19342011 self .entrypoint = self .framework_entrypoint_command + [user_script_location ]
19352012
1936- def _create_and_upload_runproc (self , user_script , kms_key , entrypoint_s3_uri ):
2013+ def _create_and_upload_runproc (
2014+ self , user_script , kms_key , entrypoint_s3_uri , codeartifact_repo_arn = None
2015+ ):
19372016 """Create runproc shell script and upload to S3 bucket.
19382017
19392018 If leveraging a pipeline session with optimized S3 artifact paths,
@@ -1949,7 +2028,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19492028 from sagemaker .workflow .utilities import _pipeline_config , hash_object
19502029
19512030 if _pipeline_config and _pipeline_config .pipeline_name :
1952- runproc_file_str = self ._generate_framework_script (user_script )
2031+ runproc_file_str = self ._generate_framework_script (user_script , codeartifact_repo_arn )
19532032 runproc_file_hash = hash_object (runproc_file_str )
19542033 s3_uri = s3 .s3_path_join (
19552034 "s3://" ,
@@ -1968,7 +2047,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19682047 )
19692048 else :
19702049 s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1971- self ._generate_framework_script (user_script ),
2050+ self ._generate_framework_script (user_script , codeartifact_repo_arn ),
19722051 desired_s3_uri = entrypoint_s3_uri ,
19732052 kms_key = kms_key ,
19742053 sagemaker_session = self .sagemaker_session ,
0 commit comments