
Commit b07f502

sagemaker module

1 parent f2636f2

4 files changed: 109 additions, 1 deletion


awswrangler/sagemaker.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import pickle
+import tarfile
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class SageMaker:
+    def __init__(self, session):
+        self._session = session
+        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+
+    @staticmethod
+    def _parse_path(path):
+        path2 = path.replace("s3://", "")
+        parts = path2.partition("/")
+        return parts[0], parts[2]
+
+    def get_job_outputs(self, path: str) -> Any:
+
+        bucket, key = SageMaker._parse_path(path)
+        if key.split("/")[-1] != "model.tar.gz":
+            key = f"{key}/model.tar.gz"
+        body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
+        body = tarfile.io.BytesIO(body)
+        tar = tarfile.open(fileobj=body)
+
+        results = []
+        for member in tar.getmembers():
+            f = tar.extractfile(member)
+            file_type = member.name.split(".")[-1]
+
+            if file_type == "pkl":
+                f = pickle.load(f)
+
+            results.append(f)
+
+        return results
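For orientation, a minimal usage sketch of the new module (the bucket and prefix below are hypothetical, and valid AWS credentials are assumed):

from awswrangler import Session

session = Session()
# get_job_outputs() appends "model.tar.gz" to the prefix if it is not already there,
# downloads the archive, extracts every member, and unpickles any member ending in ".pkl".
outputs = session.sagemaker.get_job_outputs("s3://my-bucket/my-training-job/output")
fitted_model = outputs[0]

Members without a ".pkl" suffix are returned as the raw extracted file objects.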

awswrangler/session.py

Lines changed: 9 additions & 0 deletions
@@ -13,6 +13,8 @@
 from awswrangler.glue import Glue
 from awswrangler.redshift import Redshift
 from awswrangler.emr import EMR
+from awswrangler.sagemaker import SageMaker
+
 
 PYSPARK_INSTALLED = False
 if importlib.util.find_spec("pyspark"):  # type: ignore
@@ -112,6 +114,7 @@ def __init__(self,
         self._glue = None
         self._redshift = None
         self._spark = None
+        self._sagemaker = None
 
     def _load_new_boto3_session(self):
         """
@@ -281,6 +284,12 @@ def redshift(self):
             self._redshift = Redshift(session=self)
         return self._redshift
 
+    @property
+    def sagemaker(self):
+        if not self._sagemaker:
+            self._sagemaker = SageMaker(session=self)
+        return self._sagemaker
+
     @property
     def spark(self):
         if not PYSPARK_INSTALLED:
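The new property follows the same lazy-initialization pattern as the existing redshift and spark accessors: the SageMaker helper is instantiated on first access and cached afterwards. An illustrative check (sketch only, assuming a valid Session):

from awswrangler import Session

session = Session()
first = session.sagemaker   # first access instantiates SageMaker(session=self)
second = session.sagemaker  # later accesses return the cached instance
assert first is second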

requirements-dev.txt

Lines changed: 3 additions & 1 deletion
@@ -3,8 +3,10 @@ mypy~=0.750
 flake8~=3.7.9
 pytest-cov~=2.8.1
 cfn-lint~=0.26.0
+scikit-learn==0.22
+sklearn==0.0
 twine~=3.1.1
 wheel~=0.33.6
 sphinx~=2.2.2
 pyspark~=2.4.4
-pyspark-stubs~=2.4.0.post6
+pyspark-stubs~=2.4.0.post6

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import os
+import pickle
+import logging
+import tarfile
+
+import boto3
+import pytest
+
+from awswrangler import Session
+from sklearn.linear_model import LinearRegression
+
+logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
+logging.getLogger("awswrangler").setLevel(logging.DEBUG)
+
+
+@pytest.fixture(scope="module")
+def session():
+    yield Session()
+
+
+@pytest.fixture(scope="module")
+def cloudformation_outputs():
+    response = boto3.client("cloudformation").describe_stacks(StackName="aws-data-wrangler-test-arena")
+    outputs = {}
+    for output in response.get("Stacks")[0].get("Outputs"):
+        outputs[output.get("OutputKey")] = output.get("OutputValue")
+    yield outputs
+
+
+@pytest.fixture(scope="module")
+def bucket(session, cloudformation_outputs):
+    if "BucketName" in cloudformation_outputs:
+        bucket = cloudformation_outputs["BucketName"]
+        session.s3.delete_objects(path=f"s3://{bucket}/")
+    else:
+        raise Exception("You must deploy the test infrastructure using Cloudformation!")
+    yield bucket
+    session.s3.delete_objects(path=f"s3://{bucket}/")
+
+
+def test_get_job_outputs(session, bucket):
+    model_path = "output"
+    s3 = boto3.resource("s3")
+
+    lr = LinearRegression()
+    with open("model.pkl", "wb") as fp:
+        pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL)
+
+    with tarfile.open("model.tar.gz", "w:gz") as tar:
+        tar.add("model.pkl")
+
+    s3.Bucket(bucket).upload_file("model.tar.gz", f"{model_path}/model.tar.gz")
+    outputs = session.sagemaker.get_job_outputs(f"{bucket}/{model_path}")
+
+    os.remove("model.pkl")
+    os.remove("model.tar.gz")
+
+    assert type(outputs[0]) == LinearRegression
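The core of what this test exercises, the tar + pickle round trip, can also be checked locally without any AWS resources. A sketch (it mirrors what get_job_outputs does after downloading the archive, using the same hypothetical local file names as the test):

import io
import pickle
import tarfile

from sklearn.linear_model import LinearRegression

# Build the same artifact the test uploads: a pickled estimator inside model.tar.gz.
with open("model.pkl", "wb") as fp:
    pickle.dump(LinearRegression(), fp, pickle.HIGHEST_PROTOCOL)
with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add("model.pkl")

# Re-open the archive from an in-memory buffer and unpickle the ".pkl" members,
# as get_job_outputs does with the bytes fetched from S3.
with open("model.tar.gz", "rb") as fp:
    buffer = io.BytesIO(fp.read())
with tarfile.open(fileobj=buffer) as tar:
    results = [pickle.load(tar.extractfile(m)) for m in tar.getmembers() if m.name.endswith(".pkl")]

assert isinstance(results[0], LinearRegression)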
