
Commit 7e8e237

Joseph Zhang authored and gwang111 committed
Disable network isolation if using sharded models.
1 parent 3e97708 commit 7e8e237

File tree

src/sagemaker/model.py
src/sagemaker/serve/builder/jumpstart_builder.py
src/sagemaker/serve/builder/model_builder.py
src/sagemaker/serve/validations/optimization.py
tests/unit/sagemaker/serve/builder/test_model_builder.py

5 files changed (+29 -19 lines changed)

src/sagemaker/model.py

Lines changed: 6 additions & 0 deletions
@@ -1607,6 +1607,12 @@ def deploy(
                 )
                 endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
 
+        if self._is_sharded_model and self._enable_network_isolation:
+            raise ValueError(
+                "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
+                "Loading of model requires network access."
+            )
+
         # Support multiple models on same endpoint
         if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
             if endpoint_name:
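
For context, a minimal sketch of how this guard surfaces to a caller. The image URI, model data path, and role below are hypothetical, and `_is_sharded_model` is an internal flag normally set by a prior sharding optimization; it is toggled by hand here only for illustration:

from sagemaker.model import Model

# Hypothetical URIs and role, for illustration only.
model = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://my-bucket/sharded-model/",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    enable_network_isolation=True,
)
model._is_sharded_model = True  # normally set internally by optimize(sharding_config=...)

# deploy() now fails fast with the ValueError above instead of launching
# an endpoint whose weights cannot be fetched under network isolation.
model.deploy(initial_instance_count=1, instance_type="ml.g5.24xlarge")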

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 0 deletions
@@ -795,6 +795,14 @@ def _optimize_for_jumpstart(
         optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
         if optimization_env_vars:
             self.pysdk_model.env.update(optimization_env_vars)
+
+        if sharding_config and self.pysdk_model._enable_network_isolation:
+            logger.warning(
+                "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
+                "Loading of model requires network access. Setting it to False."
+            )
+            self.pysdk_model._enable_network_isolation = False
+
         if quantization_config or sharding_config or is_compilation:
             return create_optimization_job_args
         return None
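
On the ModelBuilder path the same conflict is only a warning: the builder repairs the configuration rather than rejecting it. A sketch, assuming a JumpStart model ID and output bucket that are illustrative rather than taken from this commit:

from sagemaker.serve.builder.model_builder import ModelBuilder

# Hypothetical JumpStart model ID and S3 output path.
model_builder = ModelBuilder(model="meta-textgeneration-llama-3-1-8b")
optimized_model = model_builder.optimize(
    instance_type="ml.g5.24xlarge",
    output_path="s3://my-bucket/optimized/",
    sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "4"}},
)
# If the resolved JumpStart model had network isolation enabled, the builder
# logs the warning above and sets _enable_network_isolation to False, so the
# optimized model remains deployable.

The asymmetry between the two paths appears deliberate: `deploy()` raises because the caller set the flag explicitly, while the JumpStart builder inherits the flag from model metadata and can safely override it.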

src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
@@ -1276,8 +1276,8 @@ def _model_builder_optimize_wrapper(
         ):
             raise ValueError(
                 (
-                    "OPTION_TENSOR_PARALLEL_DEGREE is required "
-                    "environment variable with Sharding config."
+                    "OPTION_TENSOR_PARALLEL_DEGREE is a required "
+                    "environment variable with sharding config."
                 )
             )
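
Alongside the wording fix, a sketch of the validation it refers to, assuming a configured `ModelBuilder` like the one above (instance type and degree are illustrative):

# Rejected: the sharding config never sets OPTION_TENSOR_PARALLEL_DEGREE.
# Raises: ValueError: OPTION_TENSOR_PARALLEL_DEGREE is a required
# environment variable with sharding config.
model_builder.optimize(
    instance_type="ml.g5.24xlarge",
    sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
)

# Accepted: the tensor-parallel degree is supplied explicitly.
model_builder.optimize(
    instance_type="ml.g5.24xlarge",
    sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "4"}},
)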

src/sagemaker/serve/validations/optimization.py

Lines changed: 11 additions & 10 deletions
@@ -211,14 +211,15 @@ def _validate_optimization_configuration(
                 f"Optimizations that use {trt_compare_error} for GPU instances."
             )
         if str(trt_compare_error) == str(vllm_compare_error):
-            joint_error_msg = f"""
-            Optimization cannot be performed for the following reasons:
-            - Optimizations that use {trt_compare_error} are not supported for GPU instances.
-            """
-        else:
-            joint_error_msg = f"""
-            Optimization cannot be performed for the following reasons:
-            - Optimizations that use {trt_compare_error} are not supported for GPU instances.
-            - Optimizations that use {vllm_compare_error} are not supported for GPU instances.
-            """
+            raise ValueError(
+                (
+                    f"Optimizations that use {trt_compare_error} "
+                    "are not supported for GPU instances."
+                )
+            )
+        joint_error_msg = f"""
+        Optimization cannot be performed for the following reasons:
+        - Optimizations that use {trt_compare_error} are not supported for GPU instances.
+        - Optimizations that use {vllm_compare_error} are not supported for GPU instances.
+        """
         raise ValueError(textwrap.dedent(joint_error_msg))
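
The control flow is easier to follow out of context: when the TensorRT-LLM and vLLM rule sets report the same finding, a single de-duplicated message is raised early; otherwise the joint two-bullet message is built as before. A standalone sketch of that pattern (function and variable names are illustrative, not the SDK's internals):

import textwrap

def raise_joint_error(trt_compare_error, vllm_compare_error):
    # Identical findings from both rule sets: raise one message early.
    if str(trt_compare_error) == str(vllm_compare_error):
        raise ValueError(
            f"Optimizations that use {trt_compare_error} "
            "are not supported for GPU instances."
        )
    # Distinct findings: report both, one bullet per rule set.
    joint_error_msg = f"""
    Optimization cannot be performed for the following reasons:
    - Optimizations that use {trt_compare_error} are not supported for GPU instances.
    - Optimizations that use {vllm_compare_error} are not supported for GPU instances.
    """
    raise ValueError(textwrap.dedent(joint_error_msg))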

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 2 additions & 7 deletions
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import textwrap
 from unittest.mock import MagicMock, patch, Mock, mock_open
 
 import unittest
@@ -2701,7 +2700,7 @@ def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
 
         self.assertRaisesRegex(
             ValueError,
-            "OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
+            "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
             lambda: model_builder.optimize(
                 instance_type="ml.g5.24xlarge",
                 sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
@@ -2876,13 +2875,9 @@ def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
             )
 
         # Invalid quantization technique
-        expected_quantization_error_message = """
-        Optimization cannot be performed for the following reasons:
-        - Optimizations that use Quantization:test are not supported for GPU instances.
-        """
         self.assertRaisesRegex(
             ValueError,
-            textwrap.dedent(expected_quantization_error_message),
+            "Optimizations that use Quantization:test are not supported for GPU instances.",
             lambda: _validate_optimization_configuration(
                 instance_type="ml.g5.24xlarge",
                 quantization_config={
