diff --git a/template/v3/Dockerfile b/template/v3/Dockerfile index 644caf8aa..6fa2b9ab2 100644 --- a/template/v3/Dockerfile +++ b/template/v3/Dockerfile @@ -7,6 +7,7 @@ ARG PINNED_ENV_IN_FILENAME ARG ARG_BASED_ENV_IN_FILENAME ARG IMAGE_VERSION LABEL "org.amazon.sagemaker-distribution.image.version"=$IMAGE_VERSION +LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true ARG AMZN_BASE="/opt/amazon/sagemaker" ARG DB_ROOT_DIR="/opt/db" diff --git a/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py b/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py index 18870de27..be09e7ef3 100644 --- a/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py +++ b/template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py @@ -5,8 +5,14 @@ from enum import Enum +class SageMakerPlatform(str, Enum): + """Simple enum to define environment variables injected by the SageMaker platform.""" + + PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT" + + class SageMakerInference(str, Enum): - """Simple enum to define the mapping between dictionary key and environement variable.""" + """Simple enum to define the mapping between dictionary key and environment variable.""" BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY" REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS" @@ -28,7 +34,7 @@ def __init__(self): SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None), SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"), SageMakerInference.LOG_LEVEL: os.getenv(SageMakerInference.LOG_LEVEL, 10), - SageMakerInference.PORT: 8080, + SageMakerInference.PORT: self._resolve_port(), } def __str__(self): @@ -57,3 +63,10 @@ def logging_level(self): @property def port(self): return self._environment_variables.get(SageMakerInference.PORT) + + def _resolve_port(self) -> int: + if os.getenv(SageMakerPlatform.PLATFORM_PORT, None): + return int(os.getenv(SageMakerPlatform.PLATFORM_PORT)) + if os.getenv(SageMakerInference.PORT, None): + return int(os.getenv(SageMakerInference.PORT)) + return 8080