
Commit 416bd08

add SageMaker inference example (#152)
1 parent b609bd8 commit 416bd08

File tree: 7 files changed, +226 −0 lines

sagemaker-inference/ReadMe.md

Lines changed: 9 additions & 0 deletions

# SageMaker Model Inference

This is a small example of how you can use LocalStack to host your PyTorch ML models.

Before using this example, you should set up your Docker client to pull the AWS Deep Learning Containers images ([more info here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)):

```bash
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com
```
1.57 MB — Binary file not shown.
4.44 KB — Binary file not shown.
81.6 KB — Binary file not shown.

sagemaker-inference/main.py

Lines changed: 95 additions & 0 deletions
```python
import json
import random
import time

import boto3
import httpx
import numpy as np
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient

from mnist import mnist_to_numpy, normalize

LOCALSTACK_ENDPOINT = "http://localhost.localstack.cloud:4566"
MODEL_BUCKET = "models"
MODEL_TAR = "./data/model.tar.gz"
MODEL_NAME = "sample"
CONFIG_NAME = "sample-cf"
ENDPOINT_NAME = "sample-ep"
CONTAINER_IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.5.0-cpu-py3"
# dummy role with a 12-digit account ID (AWS account IDs have exactly 12 digits)
EXECUTION_ROLE_ARN = "arn:aws:iam::000000000000:role/sagemaker-role"

sagemaker: SageMakerClient = boto3.client("sagemaker", endpoint_url=LOCALSTACK_ENDPOINT)
sagemaker_runtime: SageMakerRuntimeClient = boto3.client("sagemaker-runtime", endpoint_url=LOCALSTACK_ENDPOINT)
s3: S3Client = boto3.client("s3", endpoint_url=LOCALSTACK_ENDPOINT)


def deploy_model(run_id: str = "0"):
    # Upload the model archive into the bucket SageMaker will load it from
    s3.create_bucket(Bucket=f"{MODEL_BUCKET}-{run_id}")
    s3.upload_file(MODEL_TAR, f"{MODEL_BUCKET}-{run_id}", f"{MODEL_NAME}.tar.gz")

    # Create the model in SageMaker, then an endpoint config and the endpoint itself
    sagemaker.create_model(
        ModelName=f"{MODEL_NAME}-{run_id}",
        ExecutionRoleArn=EXECUTION_ROLE_ARN,
        PrimaryContainer={
            "Image": CONTAINER_IMAGE,
            "ModelDataUrl": f"s3://{MODEL_BUCKET}-{run_id}/{MODEL_NAME}.tar.gz",
        },
    )
    sagemaker.create_endpoint_config(
        EndpointConfigName=f"{CONFIG_NAME}-{run_id}",
        ProductionVariants=[{
            "VariantName": f"var-{run_id}",
            "ModelName": f"{MODEL_NAME}-{run_id}",
            "InitialInstanceCount": 1,
            "InstanceType": "ml.m5.large",
        }],
    )
    sagemaker.create_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", EndpointConfigName=f"{CONFIG_NAME}-{run_id}")


def _get_input_dict():
    # Pick two random images from the MNIST test set and normalize them
    X, Y = mnist_to_numpy("data/mnist", train=False)
    mask = random.sample(range(X.shape[0]), 2)
    samples = X[mask]

    samples = normalize(samples.astype(np.float32), axis=(1, 2))
    # add the channel dimension expected by the PyTorch model: (N, 1, 28, 28)
    return {
        "inputs": np.expand_dims(samples, axis=1).tolist()
    }


def _show_predictions(response):
    predictions = np.argmax(np.array(response, dtype=np.float32), axis=1).tolist()
    print(f"Predicted digits: {predictions}")


def inference_model_container(run_id: str = "0"):
    # Call the model container directly on the port LocalStack exposes via a resource tag
    ep = sagemaker.describe_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}")
    arn = ep["EndpointArn"]
    tag_list = sagemaker.list_tags(ResourceArn=arn)
    port = "4510"
    for tag in tag_list["Tags"]:
        if tag["Key"] == "_LS_ENDPOINT_PORT_":
            port = tag["Value"]
    inputs = _get_input_dict()
    response = httpx.post(
        f"http://localhost.localstack.cloud:{port}/invocations",
        json=inputs,
        headers={"Content-Type": "application/json", "Accept": "application/json"},
    )
    _show_predictions(json.loads(response.text))


def inference_model_boto3(run_id: str = "0"):
    # Invoke the endpoint through the regular SageMaker runtime API
    inputs = _get_input_dict()
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=f"{ENDPOINT_NAME}-{run_id}",
        Body=json.dumps(inputs),
        Accept="application/json",
        ContentType="application/json",
    )
    _show_predictions(json.loads(response["Body"].read()))


def _short_uid():
    import uuid

    return str(uuid.uuid4())[:8]


if __name__ == '__main__':
    test_run = _short_uid()
    deploy_model(test_run)
    # wait some time to avoid connection resets in log output
    # -> not essential as the container spins up quickly enough within the retries of boto
    time.sleep(2)
    inference_model_boto3(test_run)
    inference_model_container(test_run)
```
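
A more robust alternative to the fixed `time.sleep(2)` above would be to poll the endpoint status until it reports `InService`. A minimal sketch, not part of the commit — `wait_for_endpoint` is a hypothetical helper that reuses the same LocalStack endpoint URL:

```python
import time

import boto3

sm = boto3.client("sagemaker", endpoint_url="http://localhost.localstack.cloud:4566")


def wait_for_endpoint(endpoint_name: str, timeout: int = 120) -> None:
    """Poll describe_endpoint until the endpoint reports InService."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = sm.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return
        if status == "Failed":
            raise RuntimeError(f"endpoint {endpoint_name} failed to deploy")
        time.sleep(2)  # back off briefly between polls
    raise TimeoutError(f"endpoint {endpoint_name} not InService after {timeout}s")
```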

sagemaker-inference/mnist.py

Lines changed: 97 additions & 0 deletions
```python
import gzip
import os

import boto3
import numpy as np

dirname = os.path.dirname(os.path.abspath(__file__))


def mnist_to_numpy(data_dir="/tmp/data", train=True):
    """Download the MNIST dataset and convert it to numpy arrays

    Args:
        data_dir (str): directory to save the data
        train (bool): download the training set instead of the test set

    Returns:
        tuple of images and labels as numpy arrays
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # download the objects from the public sagemaker-sample-files bucket
    s3 = boto3.client("s3")
    bucket = "sagemaker-sample-files"
    for obj in [images_file, labels_file]:
        # build the key with "/" explicitly so it is a valid S3 key on every platform
        key = f"datasets/image/MNIST/{obj}"
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)

    return _convert_to_numpy(data_dir, images_file, labels_file)


def _convert_to_numpy(data_dir, images_file, labels_file):
    """Byte string to numpy arrays"""
    with gzip.open(os.path.join(data_dir, images_file), "rb") as f:
        # skip the 16-byte IDX header, then reshape into 28x28 images
        images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)

    with gzip.open(os.path.join(data_dir, labels_file), "rb") as f:
        # skip the 8-byte IDX header
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    return (images, labels)


def normalize(x, axis):
    eps = np.finfo(float).eps

    mean = np.mean(x, axis=axis, keepdims=True)
    # avoid division by zero
    std = np.std(x, axis=axis, keepdims=True) + eps
    return (x - mean) / std


def adjust_to_framework(x, framework="pytorch"):
    """Adjust a ``numpy.ndarray`` to be used as input for the specified framework

    Args:
        x (numpy.ndarray): Batch of images to be adjusted
            to follow the convention in pytorch / tensorflow / mxnet

        framework (str): Framework to use. Takes a value in
            ``pytorch``, ``tensorflow`` or ``mxnet``

    Return:
        numpy.ndarray following the convention of tensors in the given
        framework
    """
    if x.ndim == 3:
        # input is gray-scale, add a channel dimension
        x = np.expand_dims(x, 1)

    if framework in ["pytorch", "mxnet"]:
        # channels-first (N, C, H, W)
        return x
    elif framework == "tensorflow":
        # channels-last (N, H, W, C)
        return np.transpose(x, (0, 2, 3, 1))
    else:
        raise ValueError(
            "framework must be one of [pytorch, tensorflow, mxnet], got {}".format(framework)
        )


if __name__ == "__main__":
    X, Y = mnist_to_numpy()
    X, Y = X.astype(np.float32), Y.astype(np.int8)
```
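
To illustrate the shape conventions `adjust_to_framework` implements, here is a small sanity check on random data — a hypothetical usage example, not part of the commit:

```python
import numpy as np

from mnist import adjust_to_framework, normalize

# fake batch of 4 gray-scale 28x28 images
batch = np.random.rand(4, 28, 28).astype(np.float32)
batch = normalize(batch, axis=(1, 2))            # zero mean, unit variance per image

x_pt = adjust_to_framework(batch, "pytorch")     # -> (4, 1, 28, 28), channels-first
x_tf = adjust_to_framework(batch, "tensorflow")  # -> (4, 28, 28, 1), channels-last
print(x_pt.shape, x_tf.shape)
```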
sagemaker-inference/requirements.txt

Lines changed: 25 additions & 0 deletions

```
anyio==3.6.1
boto3==1.24.85
boto3-stubs==1.24.85
botocore==1.27.85
botocore-stubs==1.27.85
certifi==2022.9.24
charset-normalizer==2.1.1
h11==0.12.0
httpcore==0.15.0
httpx==0.23.0
idna==3.4
jmespath==1.0.1
mypy-boto3-s3==1.24.76
mypy-boto3-sagemaker==1.24.84
mypy-boto3-sagemaker-runtime==1.24.84
numpy==1.23.3
python-dateutil==2.8.2
rfc3986==1.5.0
s3transfer==0.6.0
six==1.16.0
sniffio==1.3.0
types-awscrt==0.14.6
types-s3transfer==0.6.0.post4
typing_extensions==4.3.0
urllib3==1.26.12
```
