Skip to content

Commit affdfbb

Browse files
authored
Merge pull request #85 from luigift/sagemaker
Sagemaker output from job name
2 parents c656474 + 25b79ab commit affdfbb

File tree

5 files changed, with 46 additions and 9 deletions.

awswrangler/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,12 +17,14 @@
1717

1818
class DynamicInstantiate:
1919

20-
__default_session = Session()
20+
__default_session = None
2121

2222
def __init__(self, service):
2323
self._service = service
2424

2525
def __getattr__(self, name):
26+
if DynamicInstantiate.__default_session is None:
27+
DynamicInstantiate.__default_session = Session()
2628
return getattr(getattr(DynamicInstantiate.__default_session, self._service), name)
2729

2830

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -92,3 +92,7 @@ class InvalidTable(Exception):
9292

9393
class InvalidParameters(Exception):
9494
pass
95+
96+
97+
class AWSCredentialsNotFound(Exception):
98+
pass

awswrangler/sagemaker.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
import pickle
22
import tarfile
33
import logging
4+
45
from typing import Any
6+
from awswrangler.exceptions import InvalidParameters
57

68
logger = logging.getLogger(__name__)
79

@@ -10,18 +12,29 @@ class SageMaker:
1012
def __init__(self, session):
1113
self._session = session
1214
self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
15+
self._client_sagemaker = session.boto3_session.client(service_name="sagemaker")
1316

1417
@staticmethod
1518
def _parse_path(path):
1619
path2 = path.replace("s3://", "")
1720
parts = path2.partition("/")
1821
return parts[0], parts[2]
1922

20-
def get_job_outputs(self, path: str) -> Any:
23+
def get_job_outputs(self, job_name: str = None, path: str = None) -> Any:
24+
25+
if path and job_name:
26+
raise InvalidParameters("Specify either path, job_arn or job_name")
27+
28+
if job_name:
29+
path = self._client_sagemaker.describe_training_job(TrainingJobName=job_name)["ModelArtifacts"]["S3ModelArtifacts"]
30+
31+
if not self._session.s3.does_object_exists(path):
32+
return None
2133

2234
bucket, key = SageMaker._parse_path(path)
2335
if key.split("/")[-1] != "model.tar.gz":
2436
key = f"{key}/model.tar.gz"
37+
2538
body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
2639
body = tarfile.io.BytesIO(body) # type: ignore
2740
tar = tarfile.open(fileobj=body)

awswrangler/session.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from awswrangler.redshift import Redshift
1515
from awswrangler.emr import EMR
1616
from awswrangler.sagemaker import SageMaker
17+
from awswrangler.exceptions import AWSCredentialsNotFound
1718

1819
PYSPARK_INSTALLED = False
1920
if importlib.util.find_spec("pyspark"): # type: ignore
@@ -77,6 +78,7 @@ def __init__(self,
7778
:param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
7879
:param redshift_temp_s3_path: redshift_temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...)
7980
"""
81+
8082
self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name)
8183
self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key
8284
if boto3_session else aws_access_key_id)
@@ -130,8 +132,11 @@ def _load_new_boto3_session(self):
130132
args["aws_secret_access_key"] = self.aws_secret_access_key
131133
self._boto3_session = boto3.Session(**args)
132134
self._profile_name = self._boto3_session.profile_name
133-
self._aws_access_key_id = self._boto3_session.get_credentials().access_key
134-
self._aws_secret_access_key = self._boto3_session.get_credentials().secret_key
135+
credentials = self._boto3_session.get_credentials()
136+
if credentials is None:
137+
raise AWSCredentialsNotFound("Please run aws configure: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html")
138+
self._aws_access_key_id = credentials.access_key
139+
self._aws_secret_access_key = credentials.secret_key
135140
self._region_name = self._boto3_session.region_name
136141

137142
def _load_new_primitives(self):

testing/test_awswrangler/test_sagemaker.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,9 @@ def bucket(session, cloudformation_outputs):
3838
session.s3.delete_objects(path=f"s3://{bucket}/")
3939

4040

41-
def test_get_job_outputs(session, bucket):
42-
model_path = "output"
43-
s3 = boto3.resource("s3")
41+
@pytest.fixture(scope="module")
42+
def model(bucket):
43+
model_path = "output/model.tar.gz"
4444

4545
lr = LinearRegression()
4646
with open("model.pkl", "wb") as fp:
@@ -49,10 +49,23 @@ def test_get_job_outputs(session, bucket):
4949
with tarfile.open("model.tar.gz", "w:gz") as tar:
5050
tar.add("model.pkl")
5151

52-
s3.Bucket(bucket).upload_file("model.tar.gz", f"{model_path}/model.tar.gz")
53-
outputs = session.sagemaker.get_job_outputs(f"{bucket}/{model_path}")
52+
s3 = boto3.resource("s3")
53+
s3.Bucket(bucket).upload_file("model.tar.gz", model_path)
54+
55+
yield f"s3://{bucket}/{model_path}"
5456

5557
os.remove("model.pkl")
5658
os.remove("model.tar.gz")
5759

60+
61+
def test_get_job_outputs_by_path(session, model):
62+
outputs = session.sagemaker.get_job_outputs(path=model)
5863
assert type(outputs[0]) == LinearRegression
64+
65+
66+
def test_get_job_outputs_by_job_id(session, bucket):
67+
pass
68+
69+
70+
def test_get_job_outputs_empty(session, bucket):
71+
pass

0 commit comments

Comments
 (0)