5
5
from enum import Enum
6
6
7
7
8
+ class SageMakerPlatform (str , Enum ):
9
+ """Simple enum to define environment variables injected by the SageMaker platform."""
10
+
11
+ PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
12
+
13
+
8
14
class SageMakerInference (str , Enum ):
9
- """Simple enum to define the mapping between dictionary key and environement variable."""
15
+ """Simple enum to define the mapping between dictionary key and environment variable."""
10
16
11
17
BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
12
18
REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
@@ -28,7 +34,7 @@ def __init__(self):
28
34
SageMakerInference .CODE_DIRECTORY : os .getenv (SageMakerInference .CODE_DIRECTORY , None ),
29
35
SageMakerInference .CODE : os .getenv (SageMakerInference .CODE , "inference.handler" ),
30
36
SageMakerInference .LOG_LEVEL : os .getenv (SageMakerInference .LOG_LEVEL , 10 ),
31
- SageMakerInference .PORT : 8080 ,
37
+ SageMakerInference .PORT : self . _resolve_port () ,
32
38
}
33
39
34
40
def __str__ (self ):
@@ -57,3 +63,10 @@ def logging_level(self):
57
63
@property
58
64
def port (self ):
59
65
return self ._environment_variables .get (SageMakerInference .PORT )
66
+
67
+ def _resolve_port (self ) -> int :
68
+ if os .getenv (SageMakerPlatform .PLATFORM_PORT , None ):
69
+ return int (os .getenv (SageMakerPlatform .PLATFORM_PORT ))
70
+ if os .getenv (SageMakerInference .PORT , None ):
71
+ return int (os .getenv (SageMakerInference .PORT ))
72
+ return 8080
0 commit comments