Skip to content

Commit 21de8b1

Browse files
authored
Merge pull request #92 from awslabs/sagemaker
Improving Sagemaker module
2 parents 4209c44 + 0095ecd commit 21de8b1

File tree

6 files changed

+148
-18
lines changed

6 files changed

+148
-18
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
* Get EMR step state
6262
* Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]]*)
6363
* Load and Unzip SageMaker jobs outputs
64+
* Load and Unzip SageMaker models
6465
* Redshift -> Parquet (S3)
6566
* Aurora -> CSV (S3) (MySQL) (NEW :star:)
6667

@@ -417,6 +418,14 @@ for row in wr.athena.query(query="...", database="..."):
417418
```py3
418419
import awswrangler as wr
419420

421+
outputs = wr.sagemaker.get_model("JOB_NAME")
422+
```
423+
424+
#### Load and unzip SageMaker job output
425+
426+
```py3
427+
import awswrangler as wr
428+
420429
outputs = wr.sagemaker.get_job_outputs("JOB_NAME")
421430
```
422431

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,7 @@ class AWSCredentialsNotFound(Exception):
104104

105105
class InvalidEngine(Exception):
106106
pass
107+
108+
109+
class InvalidSagemakerOutput(Exception):
110+
pass

awswrangler/sagemaker.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Any
1+
from typing import Any, Dict
22
import pickle
33
import tarfile
44
import logging
55

6-
from awswrangler.exceptions import InvalidParameters
6+
from awswrangler.exceptions import InvalidParameters, InvalidSagemakerOutput
77

88
logger = logging.getLogger(__name__)
99

@@ -22,34 +22,68 @@ def _parse_path(path):
2222
parts = path2.partition("/")
2323
return parts[0], parts[2]
2424

25-
def get_job_outputs(self, job_name: str = None, path: str = None) -> Any:
25+
def get_job_outputs(self, job_name: str = None, path: str = None) -> Dict[str, Any]:
26+
"""
27+
Extract and deserialize all Sagemaker's outputs (everything inside model.tar.gz)
28+
29+
:param job_name: Sagemaker's job name
30+
:param path: S3 path (model.tar.gz path)
31+
:return: A Dictionary with all filenames (key) and all objects (values)
32+
"""
2633

2734
if path and job_name:
28-
raise InvalidParameters("Specify either path, job_arn or job_name")
35+
raise InvalidParameters("Specify either path or job_name")
2936

3037
if job_name:
3138
path = self._client_sagemaker.describe_training_job(
3239
TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"]
3340

34-
if not self._session.s3.does_object_exists(path):
35-
return None
41+
if path is not None:
42+
if path.split("/")[-1] != "model.tar.gz":
43+
path = f"{path}/model.tar.gz"
3644

37-
bucket, key = SageMaker._parse_path(path)
38-
if key.split("/")[-1] != "model.tar.gz":
39-
key = f"{key}/model.tar.gz"
45+
if self._session.s3.does_object_exists(path) is False:
46+
raise InvalidSagemakerOutput(f"Path does not exists ({path})")
4047

48+
bucket: str
49+
key: str
50+
bucket, key = SageMaker._parse_path(path)
4151
body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
4252
body = tarfile.io.BytesIO(body) # type: ignore
4353
tar = tarfile.open(fileobj=body)
4454

45-
results = []
46-
for member in tar.getmembers():
55+
members = tar.getmembers()
56+
if len(members) < 1:
57+
raise InvalidSagemakerOutput(f"No artifacts found in {path}")
58+
59+
results: Dict[str, Any] = {}
60+
for member in members:
61+
logger.debug(f"member: {member.name}")
4762
f = tar.extractfile(member)
48-
file_type = member.name.split(".")[-1]
63+
file_type: str = member.name.split(".")[-1]
4964

5065
if (file_type == "pkl") and (f is not None):
5166
f = pickle.load(f)
5267

53-
results.append(f)
68+
results[member.name] = f
5469

5570
return results
71+
72+
def get_model(self, job_name: str = None, path: str = None, model_name: str = None) -> Any:
73+
"""
74+
Extract and deserialize a Sagemaker's output model (.tar.gz)
75+
76+
:param job_name: Sagemaker's job name
77+
:param path: S3 path (model.tar.gz path)
78+
:param model_name: model name (e.g: model.pkl)
79+
:return:
80+
"""
81+
outputs: Dict[str, Any] = self.get_job_outputs(job_name=job_name, path=path)
82+
outputs_len: int = len(outputs)
83+
if model_name in outputs:
84+
return outputs[model_name]
85+
elif outputs_len > 1:
86+
raise InvalidSagemakerOutput(
87+
f"Number of artifacts found: {outputs_len}. Please, specify a model_name or use the Sagemaker.get_job_outputs() method."
88+
)
89+
return list(outputs.values())[0]

docs/source/examples.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,15 @@ Athena query to receive the result as python primitives (Iterable[Dict[str, Any]
370370
for row in wr.athena.query(query="...", database="..."):
371371
print(row)
372372
373+
Load and unzip SageMaker model
374+
``````````````````````````````
375+
376+
.. code-block:: python
377+
378+
import awswrangler as wr
379+
380+
outputs = wr.sagemaker.get_model("JOB_NAME")
381+
373382
Load and unzip SageMaker job output
374383
```````````````````````````````````
375384

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ General
5454
* Get EMR step state
5555
* Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]]*)
5656
* Load and Unzip SageMaker jobs outputs
57+
* Load and Unzip SageMaker models
5758
* Redshift -> Parquet (S3)
5859
* Aurora -> CSV (S3) (MySQL) (NEW :star:)
5960

testing/test_awswrangler/test_sagemaker.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import boto3
77
import pytest
88

9+
import awswrangler as wr
910
from awswrangler import Session
11+
from awswrangler.exceptions import InvalidSagemakerOutput
1012
from sklearn.linear_model import LinearRegression
1113

1214
logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
@@ -54,18 +56,89 @@ def model(bucket):
5456

5557
yield f"s3://{bucket}/{model_path}"
5658

57-
os.remove("model.pkl")
58-
os.remove("model.tar.gz")
59+
try:
60+
os.remove("model.pkl")
61+
except OSError:
62+
pass
63+
try:
64+
os.remove("model.tar.gz")
65+
except OSError:
66+
pass
67+
68+
69+
@pytest.fixture(scope="module")
70+
def model_empty(bucket):
71+
model_path = "output_empty/model.tar.gz"
72+
73+
with tarfile.open("model.tar.gz", "w:gz") as tar:
74+
pass
75+
76+
s3 = boto3.resource("s3")
77+
s3.Bucket(bucket).upload_file("model.tar.gz", model_path)
78+
79+
yield f"s3://{bucket}/{model_path}"
80+
81+
try:
82+
os.remove("model.tar.gz")
83+
except OSError:
84+
pass
85+
86+
87+
@pytest.fixture(scope="module")
88+
def model_double(bucket):
89+
model_path = "output_double/model.tar.gz"
90+
91+
lr = LinearRegression()
92+
with open("model.pkl", "wb") as fp:
93+
pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL)
94+
95+
with open("model2.pkl", "wb") as fp:
96+
pickle.dump(lr, fp, pickle.HIGHEST_PROTOCOL)
97+
98+
with tarfile.open("model.tar.gz", "w:gz") as tar:
99+
tar.add("model.pkl")
100+
tar.add("model2.pkl")
101+
102+
s3 = boto3.resource("s3")
103+
s3.Bucket(bucket).upload_file("model.tar.gz", model_path)
104+
105+
yield f"s3://{bucket}/{model_path}"
106+
107+
try:
108+
os.remove("model.pkl")
109+
except OSError:
110+
pass
111+
try:
112+
os.remove("model2.pkl")
113+
except OSError:
114+
pass
115+
try:
116+
os.remove("model.tar.gz")
117+
except OSError:
118+
pass
59119

60120

61121
def test_get_job_outputs_by_path(session, model):
62122
outputs = session.sagemaker.get_job_outputs(path=model)
63-
assert type(outputs[0]) == LinearRegression
123+
assert type(list(outputs.values())[0]) == LinearRegression
64124

65125

66126
def test_get_job_outputs_by_job_id(session, bucket):
67127
pass
68128

69129

70-
def test_get_job_outputs_empty(session, bucket):
71-
pass
130+
def test_get_model_empty(model_empty):
131+
with pytest.raises(InvalidSagemakerOutput):
132+
wr.sagemaker.get_model(path=model_empty)
133+
134+
135+
def test_get_model_double(session, model_double):
136+
with pytest.raises(InvalidSagemakerOutput):
137+
wr.sagemaker.get_model(path=model_double)
138+
model = session.sagemaker.get_model(path=model_double, model_name="model.pkl")
139+
assert type(model) == LinearRegression
140+
141+
142+
def test_get_model_by_path(session, model):
143+
model = session.sagemaker.get_model(path=model)
144+
assert type(model) == LinearRegression

0 commit comments

Comments
 (0)