Skip to content

Commit a5fdd54

Browse files
committed
add custom inference port control to enable accept-bind-to-port
1 parent 1a8482b commit a5fdd54

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

template/v3/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ARG PINNED_ENV_IN_FILENAME
77
ARG ARG_BASED_ENV_IN_FILENAME
88
ARG IMAGE_VERSION
99
LABEL "org.amazon.sagemaker-distribution.image.version"=$IMAGE_VERSION
10+
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
1011

1112
ARG AMZN_BASE="/opt/amazon/sagemaker"
1213
ARG DB_ROOT_DIR="/opt/db"

template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,14 @@
55
from enum import Enum
66

77

8+
class SageMakerPlatform(str, Enum):
9+
"""Simple enum to define environment variables injected by the SageMaker platform."""
10+
11+
PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
12+
13+
814
class SageMakerInference(str, Enum):
9-
"""Simple enum to define the mapping between dictionary key and environement variable."""
15+
"""Simple enum to define the mapping between dictionary key and environment variable."""
1016

1117
BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
1218
REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
@@ -28,7 +34,7 @@ def __init__(self):
2834
SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None),
2935
SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"),
3036
SageMakerInference.LOG_LEVEL: os.getenv(SageMakerInference.LOG_LEVEL, 10),
31-
SageMakerInference.PORT: 8080,
37+
SageMakerInference.PORT: self._resolve_port(),
3238
}
3339

3440
def __str__(self):
@@ -57,3 +63,10 @@ def logging_level(self):
5763
@property
5864
def port(self):
5965
return self._environment_variables.get(SageMakerInference.PORT)
66+
67+
def _resolve_port(self) -> int:
68+
if os.getenv(SageMakerPlatform.PLATFORM_PORT, None):
69+
return int(os.getenv(SageMakerPlatform.PLATFORM_PORT))
70+
if os.getenv(SageMakerInference.PORT, None):
71+
return int(os.getenv(SageMakerInference.PORT))
72+
return 8080

0 commit comments

Comments
 (0)