Skip to content

Commit 49f07e6

Browse files
committed
add custom inference port control to enable accept-bind-to-port
1 parent 7c3f13b commit 49f07e6

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

template/v3/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ARG PINNED_ENV_IN_FILENAME
77
ARG ARG_BASED_ENV_IN_FILENAME
88
ARG IMAGE_VERSION
99
LABEL "org.amazon.sagemaker-distribution.image.version"=$IMAGE_VERSION
10+
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
1011

1112
ARG AMZN_BASE="/opt/amazon/sagemaker"
1213
ARG DB_ROOT_DIR="/opt/db"

template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,7 @@ async def handle(handler: callable, environment: Environment):
7272
]
7373
)
7474
app.listen(environment.port)
75+
if environment.safe_port:
76+
app.listen(environment.safe_port)
7577
logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")
7678
await asyncio.Event().wait()

template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,7 @@ async def handle(handler: callable, environment: Environment):
7373
]
7474
)
7575
app.listen(environment.port)
76+
if environment.safe_port:
77+
app.listen(environment.safe_port)
7678
logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")
7779
await asyncio.Event().wait()

template/v3/dirs/etc/sagemaker-inference-server/utils/environment.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,27 @@
22

33
import json
44
import os
5+
import random
56
from enum import Enum
7+
from typing import Optional
8+
9+
10+
class SageMakerPlatform(str, Enum):
    """Names of the environment variables that the SageMaker platform injects."""

    # Holds the port the platform asks the container to bind to.
    PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
614

715

816
class SageMakerInference(str, Enum):
9-
"""Simple enum to define the mapping between dictionary key and environement variable."""
17+
"""Simple enum to define the mapping between dictionary key and environment variable."""
1018

1119
BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
1220
REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
1321
CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
1422
CODE = "SAGEMAKER_INFERENCE_CODE"
1523
LOG_LEVEL = "SAGEMAKER_INFERENCE_LOG_LEVEL"
1624
PORT = "SAGEMAKER_INFERENCE_PORT"
25+
SAFE_PORT = "SAGEMAKER_SAFE_PORT_RANGE"
1726

1827

1928
class Environment:
@@ -28,7 +37,8 @@ def __init__(self):
2837
SageMakerInference.CODE_DIRECTORY: os.getenv(SageMakerInference.CODE_DIRECTORY, None),
2938
SageMakerInference.CODE: os.getenv(SageMakerInference.CODE, "inference.handler"),
3039
SageMakerInference.LOG_LEVEL: os.getenv(SageMakerInference.LOG_LEVEL, 10),
31-
SageMakerInference.PORT: 8080,
40+
SageMakerInference.PORT: self._resolve_port(),
41+
SageMakerInference.SAFE_PORT: self._resolve_from_safe_port_range(),
3242
}
3343

3444
def __str__(self):
@@ -57,3 +67,25 @@ def logging_level(self):
5767
@property
def port(self):
    """Return the value stored under ``SageMakerInference.PORT`` (``None`` if absent)."""
    primary_port = self._environment_variables.get(SageMakerInference.PORT)
    return primary_port
70+
71+
@property
def safe_port(self):
    """Return the value stored under ``SageMakerInference.SAFE_PORT`` (``None`` if absent)."""
    secondary_port = self._environment_variables.get(SageMakerInference.SAFE_PORT)
    return secondary_port
74+
75+
def _resolve_port(self) -> int:
    """Resolve the port number from the environment.

    The variables are consulted in order and the first non-empty one wins:
    ``SAGEMAKER_BIND_TO_PORT`` first, then ``SAGEMAKER_INFERENCE_PORT``.
    When neither is set, the default ``8080`` is used.

    Returns:
        The resolved port number.

    Raises:
        ValueError: If the selected variable does not hold an integer.
    """
    candidates = (SageMakerPlatform.PLATFORM_PORT, SageMakerInference.PORT)
    for variable in candidates:
        raw_value = os.getenv(variable, None)
        if raw_value:
            return int(raw_value)
    return 8080
81+
82+
def _resolve_from_safe_port_range(self) -> Optional[int]:
    """Pick a random port from the ``SAGEMAKER_SAFE_PORT_RANGE`` environment variable.

    The variable is read as ``"<lower>-<upper>"``; both bounds are inclusive.

    Returns:
        A random port in ``[lower, upper]``, or ``None`` when the variable is
        unset, empty, or cannot be parsed as an ascending integer range.
    """
    safe_port_range = os.getenv(SageMakerInference.SAFE_PORT, None)
    if not safe_port_range:
        return None

    # partition() never raises, whereas unpacking split("-") fails on a value
    # that has no dash or more than one dash.
    lower, _, upper = safe_port_range.partition("-")
    try:
        lower_bound, upper_bound = int(lower), int(upper)
    except ValueError:
        # A bound is missing or non-numeric (e.g. "8081", "8081-", "a-b").
        return None

    # random.randint raises ValueError when the range is empty.
    if lower_bound > upper_bound:
        return None

    return random.randint(lower_bound, upper_bound)

0 commit comments

Comments
 (0)