Skip to content

Commit 3e69cbe

Browse files
jiapinw and benieric authored
Enable private docker registry support for ModelBuilder (#4399)
* Initial commit for image_config support in ModelBuilder
* Add vpc_config to model builder
* fix black format
* fix flake8
* try to fix ut
* try to fix ut
* fix constants import in UT
* fix ut
* fix test_model_builder ut
* fix test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode
* Add documentation
* fix black-format, flake8 and unnecessary usecase with jumpstart models

---------

Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 2fff890 commit 3e69cbe

File tree

8 files changed

+84
-18
lines changed

8 files changed

+84
-18
lines changed

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(self):
8383
self.mode = None
8484
self.model_server = None
8585
self.image_uri = None
86+
self.image_config = None
87+
self.vpc_config = None
8688
self._original_deploy = None
8789
self.secret_key = None
8890
self.engine = None
@@ -138,6 +140,8 @@ def _create_djl_model(self) -> Type[Model]:
138140
"source_dir": code_dir,
139141
"env": self.env_vars,
140142
"hf_hub_token": self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
143+
"image_config": self.image_config,
144+
"vpc_config": self.vpc_config,
141145
}
142146

143147
if self.engine == _DjlEngine.DEEPSPEED:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from sagemaker.serve.validations.check_image_and_hardware_type import (
5555
validate_image_uri_and_hardware,
5656
)
57+
from sagemaker.workflow.entities import PipelineVariable
5758
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
5859

5960
logger = logging.getLogger(__name__)
@@ -81,7 +82,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
8182
8283
* ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint
8384
* ``Mode.LOCAL_CONTAINER``: Launch locally with a container
84-
8585
shared_libs (List[str]): Any shared libraries you want to bring into
8686
the model packaging.
8787
dependencies (Optional[Dict[str, Any]): The dependencies of the model
@@ -122,6 +122,15 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
122122
``invoke`` and ``load`` functions.
123123
image_uri (Optional[str]): The container image uri (which is derived from a
124124
SageMaker-based container).
125+
image_config (dict[str, str] or dict[str, PipelineVariable]): Specifies
126+
whether the image of model container is pulled from ECR, or private
127+
registry in your VPC. By default it is set to pull model container
128+
image from ECR. (default: None).
129+
vpc_config (Optional[Dict[str, List[Union[str, PipelineVariable]]]]):
130+
The VpcConfig set on the model (default: None)
131+
* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids.
132+
* 'SecurityGroupIds' (List[Union[str, PipelineVariable]]): List of security group
133+
ids.
125134
model_server (Optional[ModelServer]): The model server to which to deploy.
126135
You need to provide this argument when you specify an ``image_uri``
127136
in order for model builder to build the artifacts correctly (according
@@ -204,6 +213,23 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
204213
image_uri: Optional[str] = field(
205214
default=None, metadata={"help": "Define the container image uri"}
206215
)
216+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = field(
217+
default=None,
218+
metadata={
219+
"help": "Specifies whether the image of model container is pulled from ECR,"
220+
" or private registry in your VPC. By default it is set to pull model "
221+
"container image from ECR. (default: None)."
222+
},
223+
)
224+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = field(
225+
default=None,
226+
metadata={
227+
"help": "The VpcConfig set on the model (default: None)."
228+
"* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids."
229+
"* ''SecurityGroupIds'' (List[Union[str, PipelineVariable]]): List of"
230+
" security group ids."
231+
},
232+
)
207233
model_server: Optional[ModelServer] = field(
208234
default=None, metadata={"help": "Define the model server to deploy to."}
209235
)
@@ -386,6 +412,8 @@ def _create_model(self):
386412
# TODO: we should create model as per the framework
387413
self.pysdk_model = Model(
388414
image_uri=self.image_uri,
415+
image_config=self.image_config,
416+
vpc_config=self.vpc_config,
389417
model_data=self.s3_upload_path,
390418
role=self.serve_settings.role_arn,
391419
env=self.env_vars,
@@ -543,15 +571,16 @@ def build(
543571
self,
544572
mode: Type[Mode] = None,
545573
role_arn: str = None,
546-
sagemaker_session: str = None,
574+
sagemaker_session: Optional[Session] = None,
547575
) -> Type[Model]:
548576
"""Create a deployable ``Model`` instance with ``ModelBuilder``.
549577
550578
Args:
551579
mode (Type[Mode], optional): The mode. Defaults to ``None``.
552580
role_arn (str, optional): The IAM role arn. Defaults to ``None``.
553-
sagemaker_session (str, optional): The SageMaker session to use
554-
for the execution. Defaults to ``None``.
581+
sagemaker_session (Optional[Session]): Session object which manages interactions
582+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
583+
function creates one using the default AWS configuration chain.
555584
556585
Returns:
557586
Type[Model]: A deployable ``Model`` object.
@@ -562,10 +591,7 @@ def build(
562591
self.mode = mode
563592
if role_arn:
564593
self.role_arn = role_arn
565-
if sagemaker_session:
566-
self.sagemaker_session = sagemaker_session
567-
elif not self.sagemaker_session:
568-
self.sagemaker_session = Session()
594+
self.sagemaker_session = sagemaker_session or Session()
569595

570596
self.sagemaker_session.settings._local_download_dir = self.model_path
571597

@@ -607,7 +633,7 @@ def save(
607633
self,
608634
save_path: Optional[str] = None,
609635
s3_path: Optional[str] = None,
610-
sagemaker_session: Optional[str] = None,
636+
sagemaker_session: Optional[Session] = None,
611637
role_arn: Optional[str] = None,
612638
) -> Type[Model]:
613639
"""WARNING: This function is experimental and not intended for production use.
@@ -618,7 +644,7 @@ def save(
618644
save_path (Optional[str]): The path where you want to save resources.
619645
s3_path (Optional[str]): The path where you want to upload resources.
620646
"""
621-
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
647+
self.sagemaker_session = sagemaker_session or Session()
622648

623649
if role_arn:
624650
self.role_arn = role_arn

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def __init__(self):
7676
self.mode = None
7777
self.model_server = None
7878
self.image_uri = None
79+
self.image_config = None
80+
self.vpc_config = None
7981
self._original_deploy = None
8082
self.hf_model_config = None
8183
self._default_tensor_parallel_degree = None
@@ -134,6 +136,8 @@ def _create_tgi_model(self) -> Type[Model]:
134136

135137
pysdk_model = HuggingFaceModel(
136138
image_uri=self.image_uri,
139+
image_config=self.image_config,
140+
vpc_config=self.vpc_config,
137141
env=self.env_vars,
138142
role=self.role_arn,
139143
sagemaker_session=self.sagemaker_session,

src/sagemaker/serve/model_server/triton/triton_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,8 @@ def _auto_detect_image_for_triton(self):
413413
def _create_triton_model(self) -> Type[Model]:
414414
self.pysdk_model = Model(
415415
image_uri=self.image_uri,
416+
image_config=self.image_config,
417+
vpc_config=self.vpc_config,
416418
model_data=self.s3_upload_path,
417419
role=self.serve_settings.role_arn,
418420
env=self.env_vars,

tests/unit/sagemaker/serve/builder/test_djl_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
LocalModelInvocationException,
3131
)
3232
from sagemaker.serve.utils.predictors import DjlLocalModePredictor
33+
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
3334

3435
mock_model_id = "TheBloke/Llama-2-7b-chat-fp16"
3536
mock_t5_model_id = "google/flan-t5-xxl"
@@ -113,6 +114,8 @@ def test_build_deploy_for_djl_local_container(
113114
schema_builder=mock_schema_builder,
114115
mode=Mode.LOCAL_CONTAINER,
115116
model_server=ModelServer.DJL_SERVING,
117+
image_config=MOCK_IMAGE_CONFIG,
118+
vpc_config=MOCK_VPC_CONFIG,
116119
)
117120

118121
builder._prepare_for_mode = MagicMock()
@@ -132,6 +135,8 @@ def test_build_deploy_for_djl_local_container(
132135
assert builder._default_max_new_tokens == 256
133136
assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 256
134137
assert builder.nb_instance_type == "ml.g5.24xlarge"
138+
assert model.image_config == MOCK_IMAGE_CONFIG
139+
assert model.vpc_config == MOCK_VPC_CONFIG
135140
assert "deepspeed" in builder.image_uri
136141

137142
builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
@@ -176,6 +181,8 @@ def test_build_for_djl_local_container_faster_transformer(
176181
schema_builder=mock_schema_builder,
177182
mode=Mode.LOCAL_CONTAINER,
178183
model_server=ModelServer.DJL_SERVING,
184+
image_config=MOCK_IMAGE_CONFIG,
185+
vpc_config=MOCK_VPC_CONFIG,
179186
)
180187
model = builder.build()
181188
builder.serve_settings.telemetry_opt_out = True
@@ -185,6 +192,8 @@ def test_build_for_djl_local_container_faster_transformer(
185192
model.generate_serving_properties()
186193
== mock_expected_fastertransformer_serving_properties
187194
)
195+
assert model.image_config == MOCK_IMAGE_CONFIG
196+
assert model.vpc_config == MOCK_VPC_CONFIG
188197
assert "fastertransformer" in builder.image_uri
189198

190199
@patch(
@@ -212,11 +221,15 @@ def test_build_for_djl_local_container_deepspeed(
212221
schema_builder=mock_schema_builder,
213222
mode=Mode.LOCAL_CONTAINER,
214223
model_server=ModelServer.DJL_SERVING,
224+
image_config=MOCK_IMAGE_CONFIG,
225+
vpc_config=MOCK_VPC_CONFIG,
215226
)
216227
model = builder.build()
217228
builder.serve_settings.telemetry_opt_out = True
218229

219230
assert isinstance(model, DeepSpeedModel)
231+
assert model.image_config == MOCK_IMAGE_CONFIG
232+
assert model.vpc_config == MOCK_VPC_CONFIG
220233
assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties
221234
assert "deepspeed" in builder.image_uri
222235

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.serve.builder.model_builder import ModelBuilder
2020
from sagemaker.serve.mode.function_pointers import Mode
2121
from sagemaker.serve.utils.types import ModelServer
22+
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
2223

2324
schema_builder = MagicMock()
2425
mock_inference_spec = Mock()
@@ -187,8 +188,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
187188

188189
mock_model_obj = Mock()
189190
mock_sdk_model.side_effect = (
190-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
191+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
191192
if image_uri == mock_image_uri
193+
and image_config == MOCK_IMAGE_CONFIG
194+
and vpc_config == MOCK_VPC_CONFIG
192195
and model_data == model_data
193196
and role == mock_role_arn
194197
and env == ENV_VARS
@@ -205,6 +208,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
205208
model=mock_fw_model,
206209
model_server=ModelServer.TORCHSERVE,
207210
image_uri=mock_image_uri,
211+
image_config=MOCK_IMAGE_CONFIG,
212+
vpc_config=MOCK_VPC_CONFIG,
208213
)
209214
build_result = builder.build(sagemaker_session=mock_session)
210215

@@ -286,7 +291,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
286291

287292
mock_model_obj = Mock()
288293
mock_sdk_model.side_effect = (
289-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
294+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
290295
if image_uri == mock_1p_dlc_image_uri
291296
and model_data == model_data
292297
and role == mock_role_arn
@@ -391,7 +396,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
391396

392397
mock_model_obj = Mock()
393398
mock_sdk_model.side_effect = (
394-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
399+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
395400
if image_uri == mock_image_uri
396401
and model_data == model_data
397402
and role == mock_role_arn
@@ -487,7 +492,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
487492

488493
mock_model_obj = Mock()
489494
mock_sdk_model.side_effect = (
490-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
495+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
491496
if image_uri == mock_image_uri
492497
and model_data == model_data
493498
and role == mock_role_arn
@@ -591,7 +596,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
591596

592597
mock_model_obj = Mock()
593598
mock_sdk_model.side_effect = (
594-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
599+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
595600
if image_uri == mock_image_uri
596601
and model_data == model_data
597602
and role == mock_role_arn
@@ -692,7 +697,7 @@ def test_build_happy_path_with_local_container_mode(
692697

693698
mock_model_obj = Mock()
694699
mock_sdk_model.side_effect = (
695-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
700+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
696701
if image_uri == mock_image_uri
697702
and model_data is None
698703
and role == mock_role_arn
@@ -809,7 +814,7 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
809814

810815
mock_model_obj = Mock()
811816
mock_sdk_model.side_effect = (
812-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
817+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
813818
if image_uri == mock_image_uri
814819
and model_data is None
815820
and role == mock_role_arn
@@ -951,7 +956,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
951956

952957
mock_model_obj = Mock()
953958
mock_sdk_model.side_effect = (
954-
lambda image_uri, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj
959+
lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: mock_model_obj # noqa E501
955960
if image_uri == mock_image_uri
956961
and model_data == model_data
957962
and role == mock_role_arn

tests/unit/sagemaker/serve/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
16+
MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"}
17+
MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]}

tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sagemaker.serve.mode.function_pointers import Mode
2121
import torch
2222

23+
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
24+
2325
TRITON_IMAGE = "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.02-py3"
2426
MODEL_PATH = "/path/to/working/dir"
2527
S3_UPLOAD_PATH = "s3://path/to/bucket"
@@ -48,6 +50,8 @@ class TritonBuilderTests(TestCase):
4850
def prepare_triton_builder_for_model(self, triton_builder: Triton) -> Triton:
4951
triton_builder.model = MOCK_PT_MODEL
5052
triton_builder.image_uri = TRITON_IMAGE
53+
triton_builder.image_config = MOCK_IMAGE_CONFIG
54+
triton_builder.vpc_config = MOCK_VPC_CONFIG
5155
triton_builder.mode = Mode.LOCAL_CONTAINER
5256
triton_builder.schema_builder = pt_schema_builder
5357
triton_builder.model_path = MODEL_PATH
@@ -90,6 +94,8 @@ def test_build_for_triton_pt(self, mock_detect_fw, mock_get_gpus, mock_path, moc
9094

9195
mock_model.assert_called_with(
9296
image_uri=TRITON_IMAGE,
97+
image_config=MOCK_IMAGE_CONFIG,
98+
vpc_config=MOCK_VPC_CONFIG,
9399
model_data=S3_UPLOAD_PATH,
94100
role=ROLE_ARN,
95101
env=ENV_VAR,
@@ -122,6 +128,8 @@ def test_build_for_triton_tf(self, mock_detect_fw, mock_get_gpus, mock_path, moc
122128

123129
mock_model.assert_called_with(
124130
image_uri=TRITON_IMAGE,
131+
image_config=MOCK_IMAGE_CONFIG,
132+
vpc_config=MOCK_VPC_CONFIG,
125133
model_data=S3_UPLOAD_PATH,
126134
role=ROLE_ARN,
127135
env=ENV_VAR,

0 commit comments

Comments
 (0)