Skip to content

Commit 9a779fe

Browse files
Joseph Zhang authored and gwang111 committed
Disable network isolation if using sharded models.
1 parent 3e97708 commit 9a779fe

File tree

4 files changed, with 17 additions and 3 deletions.

src/sagemaker/model.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1607,6 +1607,12 @@ def deploy(
16071607
)
16081608
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
16091609

1610+
if self._is_sharded_model and self._enable_network_isolation:
1611+
raise ValueError(
1612+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1613+
"Loading of model requires network access."
1614+
)
1615+
16101616
# Support multiple models on same endpoint
16111617
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16121618
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -795,6 +795,14 @@ def _optimize_for_jumpstart(
795795
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
796796
if optimization_env_vars:
797797
self.pysdk_model.env.update(optimization_env_vars)
798+
799+
if sharding_config and self.pysdk_model._enable_network_isolation:
800+
logger.warning(
801+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
802+
"Loading of model requires network access. Setting it to False."
803+
)
804+
self.pysdk_model._enable_network_isolation = False
805+
798806
if quantization_config or sharding_config or is_compilation:
799807
return create_optimization_job_args
800808
return None

src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1276,8 +1276,8 @@ def _model_builder_optimize_wrapper(
12761276
):
12771277
raise ValueError(
12781278
(
1279-
"OPTION_TENSOR_PARALLEL_DEGREE is required "
1280-
"environment variable with Sharding config."
1279+
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
1280+
"environment variable with sharding config."
12811281
)
12821282
)
12831283

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2701,7 +2701,7 @@ def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
27012701

27022702
self.assertRaisesRegex(
27032703
ValueError,
2704-
"OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
2704+
"OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
27052705
lambda: model_builder.optimize(
27062706
instance_type="ml.g5.24xlarge",
27072707
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},

0 commit comments

Comments
 (0)