22
33import json
44import os
5+ import random
56from enum import Enum
7+ from typing import Optional
8+
9+
10+ class SageMakerPlatform (str , Enum ):
11+ """Simple enum to define environment variables injected by the SageMaker platform."""
12+
13+ PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
614
715
816class SageMakerInference (str , Enum ):
9- """Simple enum to define the mapping between dictionary key and environement variable."""
17+ """Simple enum to define the mapping between dictionary key and environment variable."""
1018
1119 BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
1220 REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
1321 CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
1422 CODE = "SAGEMAKER_INFERENCE_CODE"
1523 LOG_LEVEL = "SAGEMAKER_INFERENCE_LOG_LEVEL"
1624 PORT = "SAGEMAKER_INFERENCE_PORT"
25+ SAFE_PORT = "SAGEMAKER_SAFE_PORT_RANGE"
1726
1827
1928class Environment :
@@ -28,7 +37,8 @@ def __init__(self):
2837 SageMakerInference .CODE_DIRECTORY : os .getenv (SageMakerInference .CODE_DIRECTORY , None ),
2938 SageMakerInference .CODE : os .getenv (SageMakerInference .CODE , "inference.handler" ),
3039 SageMakerInference .LOG_LEVEL : os .getenv (SageMakerInference .LOG_LEVEL , 10 ),
31- SageMakerInference .PORT : 8080 ,
40+ SageMakerInference .PORT : self ._resolve_port (),
41+ SageMakerInference .SAFE_PORT : self ._resolve_from_safe_port_range (),
3242 }
3343
3444 def __str__ (self ):
@@ -57,3 +67,25 @@ def logging_level(self):
5767 @property
5868 def port (self ):
5969 return self ._environment_variables .get (SageMakerInference .PORT )
70+
71+ @property
72+ def safe_port (self ):
73+ return self ._environment_variables .get (SageMakerInference .SAFE_PORT )
74+
75+ def _resolve_port (self ) -> int :
76+ if os .getenv (SageMakerPlatform .PLATFORM_PORT , None ):
77+ return int (os .getenv (SageMakerPlatform .PLATFORM_PORT ))
78+ if os .getenv (SageMakerInference .PORT , None ):
79+ return int (os .getenv (SageMakerInference .PORT ))
80+ return 8080
81+
82+ def _resolve_from_safe_port_range (self ) -> Optional [int ]:
83+ safe_port_range = os .getenv (SageMakerInference .SAFE_PORT , None )
84+ if safe_port_range :
85+ lower , upper = safe_port_range .split ("-" )
86+ if not upper :
87+ return None
88+
89+ return random .randint (int (lower ), int (upper ))
90+
91+ return None
0 commit comments