1- from typing import Any
1+ from typing import Any , Dict
22import pickle
33import tarfile
44import logging
55
6- from awswrangler .exceptions import InvalidParameters
6+ from awswrangler .exceptions import InvalidParameters , InvalidSagemakerOutput
77
88logger = logging .getLogger (__name__ )
99
@@ -22,34 +22,68 @@ def _parse_path(path):
2222 parts = path2 .partition ("/" )
2323 return parts [0 ], parts [2 ]
2424
25- def get_job_outputs (self , job_name : str = None , path : str = None ) -> Any :
25+ def get_job_outputs (self , job_name : str = None , path : str = None ) -> Dict [str , Any ]:
26+ """
27+ Extract and deserialize all Sagemaker's outputs (everything inside model.tar.gz)
28+
29+ :param job_name: Sagemaker's job name
30+ :param path: S3 path (model.tar.gz path)
31+ :return: A Dictionary with all filenames (key) and all objects (values)
32+ """
2633
2734 if path and job_name :
28- raise InvalidParameters ("Specify either path, job_arn or job_name" )
35+ raise InvalidParameters ("Specify either path or job_name" )
2936
3037 if job_name :
3138 path = self ._client_sagemaker .describe_training_job (
3239 TrainingJobName = job_name )["ModelArtifacts" ]["S3ModelArtifacts" ]
3340
34- if not self ._session .s3 .does_object_exists (path ):
35- return None
41+ if path is not None :
42+ if path .split ("/" )[- 1 ] != "model.tar.gz" :
43+ path = f"{ path } /model.tar.gz"
3644
37- bucket , key = SageMaker ._parse_path (path )
38- if key .split ("/" )[- 1 ] != "model.tar.gz" :
39- key = f"{ key } /model.tar.gz"
45+ if self ._session .s3 .does_object_exists (path ) is False :
46+ raise InvalidSagemakerOutput (f"Path does not exists ({ path } )" )
4047
48+ bucket : str
49+ key : str
50+ bucket , key = SageMaker ._parse_path (path )
4151 body = self ._client_s3 .get_object (Bucket = bucket , Key = key )["Body" ].read ()
4252 body = tarfile .io .BytesIO (body ) # type: ignore
4353 tar = tarfile .open (fileobj = body )
4454
45- results = []
46- for member in tar .getmembers ():
55+ members = tar .getmembers ()
56+ if len (members ) < 1 :
57+ raise InvalidSagemakerOutput (f"No artifacts found in { path } " )
58+
59+ results : Dict [str , Any ] = {}
60+ for member in members :
61+ logger .debug (f"member: { member .name } " )
4762 f = tar .extractfile (member )
48- file_type = member .name .split ("." )[- 1 ]
63+ file_type : str = member .name .split ("." )[- 1 ]
4964
5065 if (file_type == "pkl" ) and (f is not None ):
5166 f = pickle .load (f )
5267
53- results . append ( f )
68+ results [ member . name ] = f
5469
5570 return results
71+
72+ def get_model (self , job_name : str = None , path : str = None , model_name : str = None ) -> Any :
73+ """
74+ Extract and deserialize a Sagemaker's output model (.tat.gz)
75+
76+ :param job_name: Sagemaker's job name
77+ :param path: S3 path (model.tar.gz path)
78+ :param model_name: model name (e.g: )
79+ :return:
80+ """
81+ outputs : Dict [str , Any ] = self .get_job_outputs (job_name = job_name , path = path )
82+ outputs_len : int = len (outputs )
83+ if model_name in outputs :
84+ return outputs [model_name ]
85+ elif outputs_len > 1 :
86+ raise InvalidSagemakerOutput (
87+ f"Number of artifacts found: { outputs_len } . Please, specify a model_name or use the Sagemaker.get_job_outputs() method."
88+ )
89+ return list (outputs .values ())[0 ]
0 commit comments