Skip to content

Commit b977f44

Browse files
committed
Adapt SageMaker to asynchronous endpoint creation
1 parent 4fbabe5 commit b977f44

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

sagemaker-inference/ReadMe.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
This is a small example about how you can use LocalStack to host your PyTorch ML models.
44

5-
Before using this example you should setup your Docker Client to pull the AWS Deep Learning images ([more info here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)):
5+
Before using this example you should set up your Docker Client to pull the AWS Deep Learning images ([more info here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)):
66

77
```bash
88
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com
99
```
1010

11-
Because the images tend to be really big (multiple GB), you might want to `docker pull` them beforehand to avoid any timeouts.
11+
Because the images tend to be heavy (multiple GB), you might want to `docker pull` them beforehand.

sagemaker-inference/main.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,39 @@
2828

2929
def deploy_model(run_id: str = "0"):
    """Stage the model artifact in S3 and create the SageMaker model,
    endpoint configuration, and endpoint for this run.

    Note: ``create_endpoint`` returns immediately; the endpoint comes up
    asynchronously — presumably callers poll for readiness (TODO confirm
    against the rest of the module).

    :param run_id: suffix appended to every resource name so that
        repeated runs do not collide.
    """
    # Resource names are all derived from the run id.
    bucket = f"{MODEL_BUCKET}-{run_id}"
    model_name = f"{MODEL_NAME}-{run_id}"
    config_name = f"{CONFIG_NAME}-{run_id}"

    # Put the model artifact where SageMaker can fetch it.
    print("Creating bucket...")
    s3.create_bucket(Bucket=bucket)
    print("Uploading model data to bucket...")
    s3.upload_file(MODEL_TAR, bucket, f"{MODEL_NAME}.tar.gz")

    # Register the model with SageMaker, pointing at the uploaded tarball.
    print("Creating model in SageMaker...")
    sagemaker.create_model(
        ModelName=model_name,
        ExecutionRoleArn=EXECUTION_ROLE_ARN,
        PrimaryContainer={
            "Image": CONTAINER_IMAGE,
            "ModelDataUrl": f"s3://{bucket}/{MODEL_NAME}.tar.gz",
        },
    )

    # One production variant, single instance.
    print("Adding endpoint configuration...")
    variant = {
        "VariantName": f"var-{run_id}",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.m5.large",
    }
    sagemaker.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[variant],
    )

    print("Creating endpoint...")
    sagemaker.create_endpoint(
        EndpointName=f"{ENDPOINT_NAME}-{run_id}",
        EndpointConfigName=config_name,
    )
4348

4449

50+
def await_endpoint(run_id: str = "0", wait: float = 0.5, max_retries=10, _retries: int = 0):
    """Poll the SageMaker endpoint until its status is ``InService``.

    Between checks the delay doubles, starting at *wait* seconds
    (exponential backoff). After *max_retries* unsuccessful checks beyond
    the first, gives up.

    :param run_id: resource-name suffix used when the endpoint was created.
    :param wait: initial delay in seconds between status checks.
    :param max_retries: how many retries to attempt before giving up.
    :param _retries: internal starting retry counter; callers leave it at 0.
    :return: ``True`` once the endpoint is in service, ``False`` if it
        never became reachable within the retry budget.
    """
    delay = wait
    attempt = _retries
    while True:
        print("Checking endpoint status...")
        status = sagemaker.describe_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}")["EndpointStatus"]
        if status == "InService":
            print("Endpoint ready!")
            return True
        if attempt == max_retries:
            print("Endpoint unreachable!")
            return False
        print("Endpoint not ready - waiting...")
        time.sleep(delay)
        delay *= 2
        attempt += 1
62+
63+
4564
def _get_input_dict():
4665
X, Y = mnist_to_numpy("data/mnist", train=False)
4766
mask = random.sample(range(X.shape[0]), 2)
@@ -67,13 +86,15 @@ def inference_model_container(run_id: str = "0"):
6786
if tag["Key"] == "_LS_ENDPOINT_PORT_":
6887
port = tag["Value"]
6988
inputs = _get_input_dict()
89+
print("Invoking endpoint directly...")
7090
response = httpx.post(f"http://localhost.localstack.cloud:{port}/invocations", json=inputs,
7191
headers={"Content-Type": "application/json", "Accept": "application/json"})
7292
_show_predictions(json.loads(response.text))
7393

7494

7595
def inference_model_boto3(run_id: str = "0"):
7696
inputs = _get_input_dict()
97+
print("Invoking via boto...")
7798
response = sagemaker_runtime.invoke_endpoint(EndpointName=f"{ENDPOINT_NAME}-{run_id}", Body=json.dumps(inputs),
7899
Accept="application/json",
79100
ContentType="application/json")
@@ -89,8 +110,7 @@ def _short_uid():
89110
if __name__ == '__main__':
    # End-to-end demo: deploy a model under a fresh run id, wait for the
    # asynchronously-created endpoint to come up, then run inference both
    # via boto3 and by hitting the serving container directly.
    test_run = _short_uid()
    deploy_model(test_run)
    if not await_endpoint(test_run):
        # raise SystemExit instead of exit(): the exit() builtin is added
        # by the site module and is not guaranteed to exist in every run
        # mode, and -1 would wrap to exit status 255 on POSIX anyway —
        # an explicit 1 signals failure unambiguously.
        raise SystemExit(1)
    inference_model_boto3(test_run)
    inference_model_container(test_run)

0 commit comments

Comments
 (0)