
Commit f2de54c

Merge pull request #84 from luigift/sagemaker
Sagemaker
2 parents 3f580d8 + 84763fb · commit f2de54c

6 files changed, +213 −52 lines changed


README.md

Lines changed: 75 additions & 51 deletions
@@ -23,6 +23,7 @@
## Use Cases

### Pandas
+
* Pandas -> Parquet (S3) (Parallel)
* Pandas -> CSV (S3) (Parallel)
* Pandas -> Glue Catalog Table
@@ -38,11 +39,13 @@
* Encrypt Pandas Dataframes on S3 with KMS keys

### PySpark
+
* PySpark -> Redshift (Parallel)
* Register Glue table from Dataframe stored on S3
* Flatten nested DataFrames

### General
+
* List S3 objects (Parallel)
* Delete S3 objects (Parallel)
* Delete listed S3 objects (Parallel)
@@ -78,8 +81,9 @@ Runs anywhere (AWS Lambda, AWS Glue Python Shell, EMR, EC2, on-premises, local,
#### Writing Pandas Dataframe to S3 + Glue Catalog

```py3
-wrangler = awswrangler.Session()
-wrangler.pandas.to_parquet(
+import awswrangler as wr
+
+wr.pandas.to_parquet(
    dataframe=dataframe,
    database="database",
    path="s3://...",
@@ -92,21 +96,24 @@ If a Glue Database name is passed, all the metadata will be created in the Glue
#### Writing Pandas Dataframe to S3 as Parquet encrypting with a KMS key

```py3
+import awswrangler as wr
+
extra_args = {
    "ServerSideEncryption": "aws:kms",
    "SSEKMSKeyId": "YOUR_KMY_KEY_ARN"
}
-wrangler = awswrangler.Session(s3_additional_kwargs=extra_args)
-wrangler.pandas.to_parquet(
+sess = wr.Session(s3_additional_kwargs=extra_args)
+sess.pandas.to_parquet(
    path="s3://..."
)
```

#### Reading from AWS Athena to Pandas

```py3
-wrangler = awswrangler.Session()
-dataframe = wrangler.pandas.read_sql_athena(
+import awswrangler as wr
+
+dataframe = wr.pandas.read_sql_athena(
    sql="select * from table",
    database="database"
)
@@ -115,21 +122,25 @@ dataframe = wrangler.pandas.read_sql_athena(
#### Reading from AWS Athena to Pandas in chunks (For memory restrictions)

```py3
-wrangler = awswrangler.Session()
-dataframe_iter = wrangler.pandas.read_sql_athena(
+import awswrangler as wr
+
+df_iter = wr.pandas.read_sql_athena(
    sql="select * from table",
    database="database",
    max_result_size=512_000_000  # 512 MB
)
-for dataframe in dataframe_iter:
-    print(dataframe)  # Do whatever you want
+
+for df in df_iter:
+    print(df)  # Do whatever you want
```

#### Reading from AWS Athena to Pandas with the blazing fast CTAS approach

```py3
-wrangler = awswrangler.Session(athena_ctas_approach=True)
-dataframe = wrangler.pandas.read_sql_athena(
+import awswrangler as wr
+
+sess = wr.Session(athena_ctas_approach=True)
+dataframe = sess.pandas.read_sql_athena(
    sql="select * from table",
    database="database"
)
@@ -138,27 +149,31 @@ dataframe = wrangler.pandas.read_sql_athena(
#### Reading from S3 (CSV) to Pandas

```py3
-wrangler = awswrangler.Session()
-dataframe = wrangler.pandas.read_csv(path="s3://...")
+import awswrangler as wr
+
+dataframe = wr.pandas.read_csv(path="s3://...")
```

#### Reading from S3 (CSV) to Pandas in chunks (For memory restrictions)

```py3
-wrangler = awswrangler.Session()
-dataframe_iter = wrangler.pandas.read_csv(
+import awswrangler as wr
+
+df_iter = wr.pandas.read_csv(
    path="s3://...",
    max_result_size=512_000_000  # 512 MB
)
-for dataframe in dataframe_iter:
-    print(dataframe)  # Do whatever you want
+
+for df in df_iter:
+    print(df)  # Do whatever you want
```

#### Reading from CloudWatch Logs Insights to Pandas

```py3
-wrangler = awswrangler.Session()
-dataframe = wrangler.pandas.read_log_query(
+import awswrangler as wr
+
+dataframe = wr.pandas.read_log_query(
    log_group_names=[LOG_GROUP_NAME],
    query="fields @timestamp, @message | sort @timestamp desc | limit 5",
)
@@ -168,14 +183,13 @@ dataframe = wrangler.pandas.read_log_query(

```py3
import pandas
-import awswrangler
+import awswrangler as wr

df = pandas.read_...  # Read from anywhere

# Typical Pandas, Numpy or Pyarrow transformation HERE!

-wrangler = awswrangler.Session()
-wrangler.pandas.to_parquet(  # Storing the data and metadata to Data Lake
+wr.pandas.to_parquet(  # Storing the data and metadata to Data Lake
    dataframe=dataframe,
    database="database",
    path="s3://...",
@@ -186,8 +200,9 @@ wrangler.pandas.to_parquet( # Storing the data and metadata to Data Lake
#### Loading Pandas Dataframe to Redshift

```py3
-wrangler = awswrangler.Session()
-wrangler.pandas.to_redshift(
+import awswrangler as wr
+
+wr.pandas.to_redshift(
    dataframe=dataframe,
    path="s3://temp_path",
    schema="...",
@@ -202,8 +217,9 @@ wrangler.pandas.to_redshift(
#### Extract Redshift query to Pandas DataFrame

```py3
-wrangler = awswrangler.Session()
-dataframe = session.pandas.read_sql_redshift(
+import awswrangler as wr
+
+dataframe = wr.pandas.read_sql_redshift(
    sql="SELECT ...",
    iam_role="YOUR_ROLE_ARN",
    connection=con,
@@ -215,8 +231,9 @@ dataframe = session.pandas.read_sql_redshift(
#### Loading PySpark Dataframe to Redshift

```py3
-wrangler = awswrangler.Session(spark_session=spark)
-wrangler.spark.to_redshift(
+import awswrangler as wr
+
+wr.spark.to_redshift(
    dataframe=df,
    path="s3://...",
    connection=conn,
@@ -230,13 +247,15 @@ wrangler.spark.to_redshift(
#### Register Glue table from Dataframe stored on S3

```py3
+import awswrangler as wr
+
dataframe.write \
    .mode("overwrite") \
    .format("parquet") \
    .partitionBy(["year", "month"]) \
    .save(compression="gzip", path="s3://...")
-wrangler = awswrangler.Session(spark_session=spark)
-wrangler.spark.create_glue_table(
+sess = wr.Session(spark_session=spark)
+sess.spark.create_glue_table(
    dataframe=dataframe,
    file_format="parquet",
    partition_by=["year", "month"],
@@ -248,8 +267,9 @@ wrangler.spark.create_glue_table(
#### Flatten nested PySpark DataFrame

```py3
-wrangler = awswrangler.Session(spark_session=spark)
-dfs = wrangler.spark.flatten(dataframe=df_nested)
+import awswrangler as wr
+sess = awswrangler.Session(spark_session=spark)
+dfs = sess.spark.flatten(dataframe=df_nested)
for name, df_flat in dfs.items():
    print(name)
    df_flat.show()
@@ -260,15 +280,17 @@ for name, df_flat in dfs.items():
#### Deleting a bunch of S3 objects (parallel)

```py3
-wrangler = awswrangler.Session()
-wrangler.s3.delete_objects(path="s3://...")
+import awswrangler as wr
+
+wr.s3.delete_objects(path="s3://...")
```

#### Get CloudWatch Logs Insights query results

```py3
-wrangler = awswrangler.Session()
-results = wrangler.cloudwatchlogs.query(
+import awswrangler as wr
+
+results = wr.cloudwatchlogs.query(
    log_group_names=[LOG_GROUP_NAME],
    query="fields @timestamp, @message | sort @timestamp desc | limit 5",
)
@@ -277,15 +299,17 @@ results = wrangler.cloudwatchlogs.query(
#### Load partitions on Athena/Glue table (repair table)

```py3
-wrangler = awswrangler.Session()
-wrangler.athena.repair_table(database="db_name", table="tbl_name")
+import awswrangler as wr
+
+wr.athena.repair_table(database="db_name", table="tbl_name")
```

#### Create EMR cluster

```py3
-wrangler = awswrangler.Session()
-cluster_id = wrangler.emr.create_cluster(
+import awswrangler as wr
+
+cluster_id = wr.emr.create_cluster(
    cluster_name="wrangler_cluster",
    logging_s3_path=f"s3://BUCKET_NAME/emr-logs/",
    emr_release="emr-5.27.0",
@@ -337,28 +361,28 @@ print(cluster_id)
#### Athena query to receive the result as python primitives (*Iterable[Dict[str, Any]*)

```py3
-wrangler = awswrangler.Session()
-for row in wrangler.athena.query(query="...", database="..."):
+import awswrangler as wr
+
+for row in wr.athena.query(query="...", database="..."):
    print(row)
```

## Diving Deep

-
### Parallelism, Non-picklable objects and GeoPandas

AWS Data Wrangler tries to parallelize everything that is possible (I/O and CPU bound task).
You can control the parallelism level using the parameters:

-- **procs_cpu_bound**: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
-- **procs_io_bound**: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
+* **procs_cpu_bound**: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
+* **procs_io_bound**: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)

Both can be defined on Session level or directly in the functions.

Some special cases will not work with parallelism:

-- GeoPandas
-- Columns with non-picklable objects
+* GeoPandas
+* Columns with non-picklable objects

To handle that use `procs_cpu_bound=1` and avoid the distribution of the dataframe.
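
For reference, a minimal sketch of how these knobs might be set through the module-level API this PR introduces, assuming `Session` and the per-call functions accept `procs_cpu_bound`/`procs_io_bound` as the text above states; the worker counts and the `s3://...` path are placeholders:

```py3
import os

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"id": [1, 2, 3]})

# Session-level tuning of the two parameters described above (values illustrative).
sess = wr.Session(
    procs_cpu_bound=os.cpu_count(),
    procs_io_bound=(os.cpu_count() or 1) * 4,
)

# Or per call, e.g. procs_cpu_bound=1 to avoid distributing the dataframe
# (the README's advice for GeoPandas / non-picklable columns).
sess.pandas.to_parquet(
    dataframe=df,
    path="s3://...",
    procs_cpu_bound=1,
)
```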

@@ -370,16 +394,16 @@ We can handle this object column fine inferring the types of theses objects insi
To work with null object columns you can explicitly set the expected Athena data type for the target table doing:

```py3
-import awswrangler
+import awswrangler as wr
import pandas as pd

dataframe = pd.DataFrame({
    "col": [1, 2],
    "col_string_null": [None, None],
    "col_date_null": [None, None],
})
-session = awswrangler.Session()
-session.pandas.to_parquet(
+
+wr.pandas.to_parquet(
    dataframe=dataframe,
    database="DATABASE",
    path=f"s3://...",
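
The hunk ends before the argument that actually declares the expected Athena types. A rough sketch of how the complete call might look, assuming the library's `cast_columns` mapping (a parameter not visible in this hunk) is the mechanism meant here; the path and type names are placeholders:

```py3
import awswrangler as wr
import pandas as pd

dataframe = pd.DataFrame({
    "col": [1, 2],
    "col_string_null": [None, None],
    "col_date_null": [None, None],
})

wr.pandas.to_parquet(
    dataframe=dataframe,
    database="DATABASE",
    path="s3://...",
    cast_columns={                 # assumed parameter: explicit Athena types for the null columns
        "col_string_null": "string",
        "col_date_null": "date",
    },
)
```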

awswrangler/__init__.py

Lines changed: 29 additions & 0 deletions
@@ -10,10 +10,39 @@
from awswrangler.glue import Glue  # noqa
from awswrangler.redshift import Redshift  # noqa
from awswrangler.emr import EMR  # noqa
+from awswrangler.sagemaker import SageMaker  # noqa
import awswrangler.utils  # noqa
import awswrangler.data_types  # noqa

+
+class DynamicInstantiate:
+
+    __default_session = Session()
+
+    def __init__(self, service):
+        self._service = service
+
+    def __getattr__(self, name):
+        return getattr(
+            getattr(
+                DynamicInstantiate.__default_session,
+                self._service
+            ),
+            name
+        )
+
+
if importlib.util.find_spec("pyspark"):  # type: ignore
    from awswrangler.spark import Spark  # noqa

+s3 = DynamicInstantiate("s3")
+emr = DynamicInstantiate("emr")
+glue = DynamicInstantiate("glue")
+spark = DynamicInstantiate("spark")
+pandas = DynamicInstantiate("pandas")
+athena = DynamicInstantiate("athena")
+redshift = DynamicInstantiate("redshift")
+sagemaker = DynamicInstantiate("sagemaker")
+cloudwatchlogs = DynamicInstantiate("cloudwatchlogs")
+
logging.getLogger("awswrangler").addHandler(logging.NullHandler())
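
The `DynamicInstantiate` wrapper above is what makes the README's new module-level calls work: attribute access on `wr.s3`, `wr.pandas`, etc. is forwarded to the matching service object of a shared default `Session`. A small sketch of the resulting call pattern (paths and names are placeholders):

```py3
import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"id": [1, 2, 3]})

# wr.pandas is a DynamicInstantiate("pandas"); looking up .to_parquet on it
# returns the bound method of the default Session's pandas service.
wr.pandas.to_parquet(
    dataframe=df,
    database="database",
    path="s3://...",
)

# Same pattern for the other services registered above.
wr.s3.delete_objects(path="s3://...")
```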

awswrangler/sagemaker.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import pickle
+import tarfile
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class SageMaker:
+    def __init__(self, session):
+        self._session = session
+        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+
+    @staticmethod
+    def _parse_path(path):
+        path2 = path.replace("s3://", "")
+        parts = path2.partition("/")
+        return parts[0], parts[2]
+
+    def get_job_outputs(self, path: str) -> Any:
+
+        bucket, key = SageMaker._parse_path(path)
+        if key.split("/")[-1] != "model.tar.gz":
+            key = f"{key}/model.tar.gz"
+        body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
+        body = tarfile.io.BytesIO(body)
+        tar = tarfile.open(fileobj=body)
+
+        results = []
+        for member in tar.getmembers():
+            f = tar.extractfile(member)
+            file_type = member.name.split(".")[-1]
+
+            if file_type == "pkl":
+                f = pickle.load(f)
+
+            results.append(f)
+
+        return results
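
A short sketch of how the new `get_job_outputs` might be called through the `wr.sagemaker` handle registered in `__init__.py`; the bucket/prefix is a placeholder, and treating the first entry as a pickled model is an assumption about what the training job wrote into `model.tar.gz`:

```py3
import awswrangler as wr

# "model.tar.gz" is appended automatically when the key does not already end with it.
outputs = wr.sagemaker.get_job_outputs("s3://BUCKET_NAME/sagemaker-output/job-name/output")

# Members come back in tar order; *.pkl members are unpickled, everything else
# is returned as the raw extracted file object.
model = outputs[0]
```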

0 commit comments
