Skip to content
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,20 @@ def deploy(
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
)

# No resources given to deploy() but present 'resources' key in deploy_kwargs means default
# JumpStart resource requirements are being used
if hasattr(self, "_is_sharded_model") and not resources and deploy_kwargs.resources:
if (
self._is_sharded_model
and deploy_kwargs.resources.num_cpus
and deploy_kwargs.resources.num_cpus > 0
):
JUMPSTART_LOGGER.warning(
"NumOfCpuCoresRequired should be 0 for the best experience with SageMaker Fast "
"Model Loading. Overriding the requested `num_cpus` to 0."
)
deploy_kwargs.resources.num_cpus = 0

self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
self.additional_model_data_sources,
deploy_kwargs.model_access_configs,
Expand Down
29 changes: 18 additions & 11 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,18 +1600,25 @@ def deploy(
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
if self._is_sharded_model:
if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._is_sharded_model and self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)
if self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

if resources and resources.num_cpus and resources.num_cpus > 0:
logger.warning(
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
)

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,3 +1482,47 @@ def test_model_source(
)

assert model_1._get_model_uri() == "s3://tmybuckaet"


@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_deploy_sharded_model_with_cpus_requested_raises_warning(
repack_model, tar_and_upload_dir, sagemaker_session
):
framework_model_classes_to_kwargs = {
HuggingFaceModel: {
"pytorch_version": "1.7.1",
"py_version": "py36",
"transformers_version": "4.6.1",
},
}

sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)

source_dir = "s3://blah/blah/blah"
for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
test_sharded_model = framework_model_class(
entry_point=ENTRY_POINT_INFERENCE,
role=ROLE,
sagemaker_session=sagemaker_session,
model_data=source_dir,
**kwargs,
)
test_sharded_model._is_sharded_model = True
from unittest import mock

with mock.patch("sagemaker.model.logger") as mock_logger:
mock_logger.warning.reset_mock()
test_sharded_model.deploy(
instance_type="ml.m2.xlarge",
initial_instance_count=INSTANCE_COUNT,
endpoint_type=EndpointType.MODEL_BASED,
resources=ResourceRequirements(
requests={"num_accelerators": 1, "memory": 8192, "copies": 1, "num_cpus": 1},
limits={},
),
)
mock_logger.warning.assert_called_once_with(
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
)
Loading