2
2
3
3
import json
4
4
import os
5
+ import random
5
6
from enum import Enum
7
+ from typing import Optional
8
+
9
+
10
+ class SageMakerPlatform (str , Enum ):
11
+ """Simple enum to define environment variables injected by the SageMaker platform."""
12
+
13
+ PLATFORM_PORT = "SAGEMAKER_BIND_TO_PORT"
6
14
7
15
8
16
class SageMakerInference (str , Enum ):
9
- """Simple enum to define the mapping between dictionary key and environement variable."""
17
+ """Simple enum to define the mapping between dictionary key and environment variable."""
10
18
11
19
BASE_DIRECTORY = "SAGEMAKER_INFERENCE_BASE_DIRECTORY"
12
20
REQUIREMENTS = "SAGEMAKER_INFERENCE_REQUIREMENTS"
13
21
CODE_DIRECTORY = "SAGEMAKER_INFERENCE_CODE_DIRECTORY"
14
22
CODE = "SAGEMAKER_INFERENCE_CODE"
15
23
LOG_LEVEL = "SAGEMAKER_INFERENCE_LOG_LEVEL"
16
24
PORT = "SAGEMAKER_INFERENCE_PORT"
25
+ SAFE_PORT = "SAGEMAKER_SAFE_PORT_RANGE"
17
26
18
27
19
28
class Environment :
@@ -28,7 +37,8 @@ def __init__(self):
28
37
SageMakerInference .CODE_DIRECTORY : os .getenv (SageMakerInference .CODE_DIRECTORY , None ),
29
38
SageMakerInference .CODE : os .getenv (SageMakerInference .CODE , "inference.handler" ),
30
39
SageMakerInference .LOG_LEVEL : os .getenv (SageMakerInference .LOG_LEVEL , 10 ),
31
- SageMakerInference .PORT : 8080 ,
40
+ SageMakerInference .PORT : self ._resolve_port (),
41
+ SageMakerInference .SAFE_PORT : self ._resolve_from_safe_port_range (),
32
42
}
33
43
34
44
def __str__ (self ):
@@ -57,3 +67,25 @@ def logging_level(self):
57
67
@property
58
68
def port (self ):
59
69
return self ._environment_variables .get (SageMakerInference .PORT )
70
+
71
+ @property
72
+ def safe_port (self ):
73
+ return self ._environment_variables .get (SageMakerInference .SAFE_PORT )
74
+
75
+ def _resolve_port (self ) -> int :
76
+ if os .getenv (SageMakerPlatform .PLATFORM_PORT , None ):
77
+ return int (os .getenv (SageMakerPlatform .PLATFORM_PORT ))
78
+ if os .getenv (SageMakerInference .PORT , None ):
79
+ return int (os .getenv (SageMakerInference .PORT ))
80
+ return 8080
81
+
82
+ def _resolve_from_safe_port_range (self ) -> Optional [int ]:
83
+ safe_port_range = os .getenv (SageMakerInference .SAFE_PORT , None )
84
+ if safe_port_range :
85
+ lower , upper = safe_port_range .split ("-" )
86
+ if not upper :
87
+ return None
88
+
89
+ return random .randint (int (lower ), int (upper ))
90
+
91
+ return None
0 commit comments