Skip to content

Commit 906ac34

Browse files
jiapinw and rvasahu-amazon
authored and committed
Initial commit for kv-cache and intelligent support in inference SDK/CLI (#252)
* Initial commit for kv-cache and intelligent support in inference SDK * Initial commit for kv-cache & intelligent routing CLI changes * [Fix] clean up test code * Separate out kv-cache and intelligent routing spec update into v1_1 inference cli * Fix ut and black format * Fix template.py extra white lines
1 parent a5d2a08 commit 906ac34

File tree

13 files changed

+1685
-177
lines changed

13 files changed

+1685
-177
lines changed

hyperpod-custom-inference-template/hyperpod_custom_inference_template/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# or in the "license" file accompanying this file. This file is
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
12+
# language governing permissions and limitations under the License.

hyperpod-custom-inference-template/hyperpod_custom_inference_template/registry.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from hyperpod_custom_inference_template.v1_0 import model as v1
14-
from hyperpod_custom_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_template
13+
from hyperpod_custom_inference_template.v1_0 import model as v1_0
14+
from hyperpod_custom_inference_template.v1_1 import model as v1_1
15+
from hyperpod_custom_inference_template.v1_0.template import (
16+
TEMPLATE_CONTENT as v1_0_template,
17+
)
18+
from hyperpod_custom_inference_template.v1_1.template import (
19+
TEMPLATE_CONTENT as v1_1_template,
20+
)
1521

1622
SCHEMA_REGISTRY = {
17-
"1.0": v1.FlatHPEndpoint,
23+
"1.0": v1_0.FlatHPEndpoint,
24+
"1.1": v1_1.FlatHPEndpoint,
1825
}
1926

20-
TEMPLATE_REGISTRY = {
21-
"1.0": v1_template
22-
}
27+
TEMPLATE_REGISTRY = {"1.0": v1_0_template, "1.1": v1_1_template}

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# or in the "license" file accompanying this file. This file is
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
12+
# language governing permissions and limitations under the License.

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
Worker,
2828
Dimensions,
2929
AutoScalingSpec,
30-
CloudWatchTrigger
30+
CloudWatchTrigger,
3131
)
3232
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
3333
from sagemaker.hyperpod.common.config.metadata import Metadata
@@ -37,12 +37,10 @@ class FlatHPEndpoint(BaseModel):
3737
model_config = ConfigDict(extra="forbid")
3838

3939
namespace: Optional[str] = Field(
40-
default=None,
41-
description="Kubernetes namespace",
42-
min_length=1
40+
default=None, description="Kubernetes namespace", min_length=1
4341
)
4442

45-
metadata_name: Optional[str] = Field(
43+
metadata_name: Optional[str] = Field(
4644
None,
4745
alias="metadata_name",
4846
description="Name of the custom endpoint object",
@@ -75,14 +73,15 @@ class FlatHPEndpoint(BaseModel):
7573

7674
# metrics.*
7775
metrics_enabled: Optional[bool] = Field(
78-
False, alias="metrics_enabled",
76+
False,
77+
alias="metrics_enabled",
7978
description="Enable metrics collection",
8079
)
8180

8281
# model_name and version
8382
model_name: str = Field(
84-
...,
85-
alias="model_name",
83+
...,
84+
alias="model_name",
8685
description="Name of model to create on SageMaker",
8786
min_length=1,
8887
max_length=63,
@@ -100,15 +99,18 @@ class FlatHPEndpoint(BaseModel):
10099

101100
# model_source_config.*
102101
model_source_type: Literal["fsx", "s3"] = Field(
103-
..., alias="model_source_type",
102+
...,
103+
alias="model_source_type",
104104
description="Source type: fsx or s3",
105105
)
106106
model_location: Optional[str] = Field(
107-
None, alias="model_location",
107+
None,
108+
alias="model_location",
108109
description="Specific model data location",
109110
)
110111
prefetch_enabled: Optional[bool] = Field(
111-
False, alias="prefetch_enabled",
112+
False,
113+
alias="prefetch_enabled",
112114
description="Whether to pre-fetch model data",
113115
)
114116

@@ -122,11 +124,12 @@ class FlatHPEndpoint(BaseModel):
122124

123125
# worker.*
124126
image_uri: str = Field(
125-
..., alias="image_uri",
127+
...,
128+
alias="image_uri",
126129
description="Inference server image name",
127130
)
128131
container_port: int = Field(
129-
...,
132+
...,
130133
alias="container_port",
131134
description="Port on which the model server listens",
132135
ge=1,
@@ -138,7 +141,8 @@ class FlatHPEndpoint(BaseModel):
138141
description="Path inside container for model volume",
139142
)
140143
model_volume_mount_name: str = Field(
141-
..., alias="model_volume_mount_name",
144+
...,
145+
alias="model_volume_mount_name",
142146
description="Name of the model volume mount",
143147
)
144148

@@ -149,7 +153,7 @@ class FlatHPEndpoint(BaseModel):
149153
description="FSX File System DNS Name",
150154
)
151155
fsx_file_system_id: Optional[str] = Field(
152-
None,
156+
None,
153157
alias="fsx_file_system_id",
154158
description="FSX File System ID",
155159
)
@@ -161,23 +165,23 @@ class FlatHPEndpoint(BaseModel):
161165

162166
# S3Storage
163167
s3_bucket_name: Optional[str] = Field(
164-
None,
168+
None,
165169
alias="s3_bucket_name",
166170
description="S3 bucket location",
167171
)
168172
s3_region: Optional[str] = Field(
169-
None,
173+
None,
170174
alias="s3_region",
171175
description="S3 bucket region",
172176
)
173177

174178
# Resources
175-
resources_limits: Optional[Dict[str, Union[int,str]]] = Field(
179+
resources_limits: Optional[Dict[str, Union[int, str]]] = Field(
176180
None,
177181
alias="resources_limits",
178182
description="Resource limits for the worker",
179183
)
180-
resources_requests: Optional[Dict[str, Union[int,str]]] = Field(
184+
resources_requests: Optional[Dict[str, Union[int, str]]] = Field(
181185
None,
182186
alias="resources_requests",
183187
description="Resource requests for the worker",
@@ -187,84 +191,82 @@ class FlatHPEndpoint(BaseModel):
187191
dimensions: Optional[Dict[str, str]] = Field(
188192
None,
189193
alias="dimensions",
190-
description="CloudWatch Metric dimensions as key–value pairs"
194+
description="CloudWatch Metric dimensions as key–value pairs",
191195
)
192196

193197
# CloudWatch Trigger
194198
metric_collection_period: Optional[int] = Field(
195-
300,
196-
description="Defines the Period for CloudWatch query"
199+
300, description="Defines the Period for CloudWatch query"
197200
)
198201
metric_collection_start_time: Optional[int] = Field(
199-
300,
200-
description="Defines the StartTime for CloudWatch query"
202+
300, description="Defines the StartTime for CloudWatch query"
201203
)
202204
metric_name: Optional[str] = Field(
203-
None,
204-
description="Metric name to query for CloudWatch trigger"
205+
None, description="Metric name to query for CloudWatch trigger"
205206
)
206207
metric_stat: Optional[str] = Field(
207208
"Average",
208209
description=(
209210
"Statistics metric to be used by Trigger. "
210211
"Defines the Stat for the CloudWatch query. Default is Average."
211-
)
212+
),
212213
)
213214
metric_type: Optional[Literal["Value", "Average"]] = Field(
214215
"Average",
215216
description=(
216217
"The type of metric to be used by HPA. "
217218
"`Average` – Uses average value per pod; "
218219
"`Value` – Uses absolute metric value."
219-
)
220+
),
220221
)
221222
min_value: Optional[float] = Field(
222223
0,
223224
description=(
224225
"Minimum metric value used in case of empty response "
225226
"from CloudWatch. Default is 0."
226-
)
227+
),
227228
)
228229
cloud_watch_trigger_name: Optional[str] = Field(
229-
None,
230-
description="Name for the CloudWatch trigger"
230+
None, description="Name for the CloudWatch trigger"
231231
)
232232
cloud_watch_trigger_namespace: Optional[str] = Field(
233-
None,
234-
description="AWS CloudWatch namespace for the metric"
233+
None, description="AWS CloudWatch namespace for the metric"
235234
)
236235
target_value: Optional[float] = Field(
237-
None,
238-
description="Target value for the CloudWatch metric"
236+
None, description="Target value for the CloudWatch metric"
239237
)
240238
use_cached_metrics: Optional[bool] = Field(
241239
True,
242240
description=(
243241
"Enable caching of metric values during polling interval. "
244242
"Default is true."
245-
)
243+
),
246244
)
247245

248246
invocation_endpoint: Optional[str] = Field(
249247
default="invocations",
250248
description=(
251249
"The invocation endpoint of the model server. http://<host>:<port>/ would be pre-populated based on the other fields. "
252250
"Please fill in the path after http://<host>:<port>/ specific to your model server.",
253-
)
251+
),
254252
)
255253

256-
@model_validator(mode='after')
254+
@model_validator(mode="after")
257255
def validate_model_source_config(self):
258256
"""Validate that required fields are provided based on model_source_type"""
259257
if self.model_source_type == "s3":
260258
if not self.s3_bucket_name or not self.s3_region:
261-
raise ValueError("s3_bucket_name and s3_region are required when model_source_type is 's3'")
259+
raise ValueError(
260+
"s3_bucket_name and s3_region are required when model_source_type is 's3'"
261+
)
262262
elif self.model_source_type == "fsx":
263263
if not self.fsx_file_system_id:
264-
raise ValueError("fsx_file_system_id is required when model_source_type is 'fsx'")
264+
raise ValueError(
265+
"fsx_file_system_id is required when model_source_type is 'fsx'"
266+
)
265267
return self
266268

267-
@model_validator(mode='after')
269+
@model_validator(mode="after")
268270
def validate_name(self):
269271
if not self.metadata_name and not self.endpoint_name:
270272
raise ValueError("Either metadata_name or endpoint_name must be provided")
@@ -273,21 +275,20 @@ def validate_name(self):
273275
def to_domain(self) -> HPEndpoint:
274276
if self.endpoint_name and not self.metadata_name:
275277
self.metadata_name = self.endpoint_name
276-
278+
277279
metadata = Metadata(name=self.metadata_name, namespace=self.namespace)
278280

279281
env_vars = None
280282
if self.env:
281283
env_vars = [
282-
EnvironmentVariables(name=k, value=v)
283-
for k, v in self.env.items()
284+
EnvironmentVariables(name=k, value=v) for k, v in self.env.items()
284285
]
285286

286287
dim_vars: list[Dimensions] = []
287288
if self.dimensions:
288289
for name, value in self.dimensions.items():
289290
dim_vars.append(Dimensions(name=name, value=value))
290-
291+
291292
cloud_watch_trigger = CloudWatchTrigger(
292293
dimensions=dim_vars,
293294
metric_collection_period=self.metric_collection_period,
@@ -300,12 +301,10 @@ def to_domain(self) -> HPEndpoint:
300301
namespace=self.cloud_watch_trigger_namespace,
301302
target_value=self.target_value,
302303
use_cached_metrics=self.use_cached_metrics,
303-
)
304-
305-
auto_scaling_spec = AutoScalingSpec(
306-
cloud_watch_trigger = cloud_watch_trigger
307304
)
308305

306+
auto_scaling_spec = AutoScalingSpec(cloud_watch_trigger=cloud_watch_trigger)
307+
309308
# nested metrics
310309
metrics = Metrics(
311310
enabled=self.metrics_enabled,
@@ -336,7 +335,9 @@ def to_domain(self) -> HPEndpoint:
336335
fsx_storage=fsx,
337336
)
338337

339-
tls = TlsConfig(tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri)
338+
tls = TlsConfig(
339+
tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri
340+
)
340341

341342
invocation_port = ModelInvocationPort(
342343
container_port=self.container_port,
@@ -368,4 +369,4 @@ def to_domain(self) -> HPEndpoint:
368369
worker=worker,
369370
invocation_endpoint=self.invocation_endpoint,
370371
auto_scaling_spec=auto_scaling_spec
371-
)
372+
)

hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@
8585
8686
invocationEndpoint: {{ invocation_endpoint }}
8787
88-
"""
88+
"""
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.

0 commit comments

Comments
 (0)