Commit be0d89e

Merge pull request #210 from awslabs/pytorch
Pytorch Module
2 parents d9f107a + 910e3b6

23 files changed: +1229 -110 lines changed

.github/workflows/static-checking.yml

Lines changed: 3 additions & 6 deletions

@@ -24,15 +24,12 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements.txt
-          pip install -r requirements-dev.txt
+      - name: Setup Environment
+        run: ./setup-dev-env.sh
       - name: CloudFormation Lint
         run: cfn-lint -t testing/cloudformation.yaml
       - name: Documentation Lint
-        run: pydocstyle awswrangler/ --add-ignore=D204
+        run: pydocstyle awswrangler/ --add-ignore=D204,D403
       - name: mypy check
         run: mypy awswrangler
       - name: Flake8 Lint

.pylintrc

Lines changed: 2 additions & 1 deletion

@@ -141,7 +141,8 @@ disable=print-statement,
         comprehension-escape,
         C0330,
         C0103,
-        W1202
+        W1202,
+        too-few-public-methods

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
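
The new `too-few-public-methods` entry silences pylint's R0903 check, which flags classes that expose fewer than two public (non-underscore) methods. Dataset-style wrappers, which typically expose only dunder methods, are the kind of class that trips it; presumably that is why the entry accompanies the new torch module. A minimal illustration of the shape R0903 complains about (the class below is illustrative, not code from this PR):

```python
# A class with only dunder members has zero "public" methods in pylint's
# accounting, so R0903 (too-few-public-methods) would flag it by default.
class ListDataset:
    """Hold a list and expose it through the sequence protocol."""

    def __init__(self, items):
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]
```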

README.md

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ df = wr.db.read_sql_query("SELECT * FROM external_schema.my_table", con=engine)
 - [11 - CSV Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/11%20-%20CSV%20Datasets.ipynb)
 - [12 - CSV Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/12%20-%20CSV%20Crawler.ipynb)
 - [13 - Merging Datasets on S3](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/13%20-%20Merging%20Datasets%20on%20S3.ipynb)
+- [14 - PyTorch](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/14%20-%20PyTorch.ipynb)
 - [15 - EMR](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/15%20-%20EMR.ipynb)
 - [16 - EMR & Docker](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/16%20-%20EMR%20%26%20Docker.ipynb)
 - [**API Reference**](https://aws-data-wrangler.readthedocs.io/en/latest/api.html)

awswrangler/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -6,9 +6,13 @@
 """

 import logging
+from importlib.util import find_spec

 from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3  # noqa
 from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa
 from awswrangler._utils import get_account_id  # noqa

+if find_spec("torch") and find_spec("torchvision") and find_spec("torchaudio") and find_spec("PIL"):
+    from awswrangler import torch  # noqa
+
 logging.getLogger("awswrangler").addHandler(logging.NullHandler())
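
The conditional import above relies on `importlib.util.find_spec`, which returns a `ModuleSpec` when a top-level package is importable and `None` otherwise, so the heavy PyTorch stack is only imported when every optional dependency is present. A minimal sketch of the same pattern (the `_HAS_TORCH` name is illustrative, not part of the library):

```python
from importlib.util import find_spec

# find_spec() only locates the package; nothing is imported yet, so probing for
# missing optional dependencies is cheap at import time.
_HAS_TORCH = all(find_spec(name) is not None for name in ("torch", "torchvision"))

if _HAS_TORCH:
    import torch  # the expensive import runs only when it is guaranteed to succeed
```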

awswrangler/_data_types.py

Lines changed: 7 additions & 7 deletions

@@ -207,7 +207,7 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
         return sqlalchemy.types.Date
     if pa.types.is_binary(dtype):
         if db_type == "redshift":
-            raise exceptions.UnsupportedType(f"Binary columns are not supported for Redshift.")  # pragma: no cover
+            raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.")  # pragma: no cover
         return sqlalchemy.types.Binary
     if pa.types.is_decimal(dtype):
         return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
@@ -257,7 +257,7 @@ def pyarrow_types_from_pandas(
     # Filling schema
     columns_types: Dict[str, pa.DataType]
     columns_types = {n: cols_dtypes[n] for n in sorted_cols}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     return columns_types


@@ -275,7 +275,7 @@ def athena_types_from_pandas(
             athena_columns_types[k] = casts[k]
         else:
             athena_columns_types[k] = pyarrow2athena(dtype=v)
-    _logger.debug(f"athena_columns_types: {athena_columns_types}")
+    _logger.debug("athena_columns_types: %s", athena_columns_types)
     return athena_columns_types


@@ -315,7 +315,7 @@ def pyarrow_schema_from_pandas(
         if (k in df.columns) and (k not in ignore):
             columns_types[k] = athena2pyarrow(v)
     columns_types = {k: v for k, v in columns_types.items() if v is not None}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     return pa.schema(fields=columns_types)


@@ -324,11 +324,11 @@ def athena_types_from_pyarrow_schema(
 ) -> Tuple[Dict[str, str], Optional[Dict[str, str]]]:
     """Extract the related Athena data types from any PyArrow Schema considering possible partitions."""
     columns_types: Dict[str, str] = {str(f.name): pyarrow2athena(dtype=f.type) for f in schema}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     partitions_types: Optional[Dict[str, str]] = None
     if partitions is not None:
         partitions_types = {p.name: pyarrow2athena(p.dictionary.type) for p in partitions}
-        _logger.debug(f"partitions_types: {partitions_types}")
+        _logger.debug("partitions_types: %s", partitions_types)
     return columns_types, partitions_types


@@ -382,5 +382,5 @@ def sqlalchemy_types_from_pandas(
             sqlalchemy_columns_types[k] = casts[k]
         else:
             sqlalchemy_columns_types[k] = pyarrow2sqlalchemy(dtype=v, db_type=db_type)
-    _logger.debug(f"sqlalchemy_columns_types: {sqlalchemy_columns_types}")
+    _logger.debug("sqlalchemy_columns_types: %s", sqlalchemy_columns_types)
     return sqlalchemy_columns_types
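
Most of the edits in this file (and in the modules below) replace f-string log messages with %-style arguments. With an f-string the message is rendered even when DEBUG records are discarded; passing the values as arguments defers formatting until a handler actually emits the record, which is also the form pylint's logging checkers prefer. A small standalone comparison (the logger name and dict are illustrative):

```python
import logging

_logger = logging.getLogger("example")

columns_types = {"col0": "int64", "col1": "string"}

# Eager: the dict is converted to a string even if DEBUG logging is disabled.
_logger.debug(f"columns_types: {columns_types}")

# Lazy: logging interpolates "%s" only when the record is actually handled.
_logger.debug("columns_types: %s", columns_types)
```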

awswrangler/athena.py

Lines changed: 13 additions & 13 deletions

@@ -176,8 +176,8 @@ def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] =
         time.sleep(_QUERY_WAIT_POLLING_DELAY)
         response = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
         state = response["QueryExecution"]["Status"]["State"]
-        _logger.debug(f"state: {state}")
-        _logger.debug(f"StateChangeReason: {response['QueryExecution']['Status'].get('StateChangeReason')}")
+        _logger.debug("state: %s", state)
+        _logger.debug("StateChangeReason: %s", response["QueryExecution"]["Status"].get("StateChangeReason"))
     if state == "FAILED":
         raise exceptions.QueryFailed(response["QueryExecution"]["Status"].get("StateChangeReason"))
     if state == "CANCELLED":
@@ -265,7 +265,7 @@ def _get_query_metadata(
     cols_types: Dict[str, str] = get_query_columns_types(
         query_execution_id=query_execution_id, boto3_session=boto3_session
     )
-    _logger.debug(f"cols_types: {cols_types}")
+    _logger.debug("cols_types: %s", cols_types)
     dtype: Dict[str, str] = {}
     parse_timestamps: List[str] = []
     parse_dates: List[str] = []
@@ -298,11 +298,11 @@ def _get_query_metadata(
             converters[col_name] = lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None
         else:
             dtype[col_name] = pandas_type
-    _logger.debug(f"dtype: {dtype}")
-    _logger.debug(f"parse_timestamps: {parse_timestamps}")
-    _logger.debug(f"parse_dates: {parse_dates}")
-    _logger.debug(f"converters: {converters}")
-    _logger.debug(f"binaries: {binaries}")
+    _logger.debug("dtype: %s", dtype)
+    _logger.debug("parse_timestamps: %s", parse_timestamps)
+    _logger.debug("parse_dates: %s", parse_dates)
+    _logger.debug("converters: %s", converters)
+    _logger.debug("binaries: %s", binaries)
     return dtype, parse_timestamps, parse_dates, converters, binaries


@@ -446,7 +446,7 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         f") AS\n"
         f"{sql}"
     )
-    _logger.debug(f"sql: {sql}")
+    _logger.debug("sql: %s", sql)
     query_id: str = start_query_execution(
         sql=sql,
         database=database,
@@ -456,7 +456,7 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         kms_key=kms_key,
         boto3_session=session,
     )
-    _logger.debug(f"query_id: {query_id}")
+    _logger.debug("query_id: %s", query_id)
     query_response: Dict[str, Any] = wait_query(query_execution_id=query_id, boto3_session=session)
     if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]:  # pragma: no cover
         reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -468,7 +468,7 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
     manifest_path: str = f"{_s3_output}/tables/{query_id}-manifest.csv"
     paths: List[str] = _extract_ctas_manifest_paths(path=manifest_path, boto3_session=session)
     chunked: Union[bool, int] = False if chunksize is None else chunksize
-    _logger.debug(f"chunked: {chunked}")
+    _logger.debug("chunked: %s", chunked)
     if not paths:
         if chunked is False:
             dfs = pd.DataFrame()
@@ -485,9 +485,9 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals
         )
         path = f"{_s3_output}/{query_id}.csv"
         s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session)
-        _logger.debug(f"Start CSV reading from {path}")
+        _logger.debug("Start CSV reading from %s", path)
         _chunksize: Optional[int] = chunksize if isinstance(chunksize, int) else None
-        _logger.debug(f"_chunksize: {_chunksize}")
+        _logger.debug("_chunksize: %s", _chunksize)
         ret = s3.read_csv(
             path=[path],
             dtype=dtype,
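
Besides the logging changes, the `wait_query` hunk shows the polling pattern the module uses: call `get_query_execution` in a loop until the query reaches a terminal state. A self-contained sketch of that idea with a simplified fixed delay (the function name and `delay` parameter are illustrative, not the library's API):

```python
import time

import boto3


def wait_athena_query(query_execution_id: str, delay: float = 0.25) -> dict:
    """Poll Athena until the query reaches a terminal state and return the response."""
    client = boto3.client("athena")
    while True:
        response = client.get_query_execution(QueryExecutionId=query_execution_id)
        state = response["QueryExecution"]["Status"]["State"]
        if state in ("SUCCEEDED", "FAILED", "CANCELLED"):
            return response
        time.sleep(delay)
```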

awswrangler/catalog.py

Lines changed: 5 additions & 5 deletions

@@ -766,7 +766,7 @@ def drop_duplicated_columns(df: pd.DataFrame) -> pd.DataFrame:
     duplicated_cols = df.columns.duplicated()
     duplicated_cols_names: List[str] = list(df.columns[duplicated_cols])
     if len(duplicated_cols_names) > 0:
-        _logger.warning(f"Dropping repeated columns: {duplicated_cols_names}")
+        _logger.warning("Dropping repeated columns: %s", duplicated_cols_names)
     return df.loc[:, ~duplicated_cols]


@@ -967,11 +967,11 @@ def _create_table(
         if name in columns_comments:
             par["Comment"] = columns_comments[name]
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
-
-    if mode == "overwrite":
+    exist: bool = does_table_exist(database=database, table=table, boto3_session=session)
+    if (mode == "overwrite") or (exist is False):
         delete_table_if_exists(database=database, table=table, boto3_session=session)
-    client_glue: boto3.client = _utils.client(service_name="glue", session=session)
-    client_glue.create_table(DatabaseName=database, TableInput=table_input)
+        client_glue: boto3.client = _utils.client(service_name="glue", session=session)
+        client_glue.create_table(DatabaseName=database, TableInput=table_input)


 def _csv_table_definition(
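
The `_create_table` change adds an existence check, so the Glue table is deleted and recreated only when `mode="overwrite"` or when it does not exist yet; an existing table is otherwise left untouched. `does_table_exist` is essentially an existence probe against the Glue Data Catalog; a hedged sketch of such a check (the `table_exists` helper below is illustrative, not the library's implementation):

```python
import boto3


def table_exists(database: str, table: str) -> bool:
    """Return True if the Glue Data Catalog already has the table."""
    client = boto3.client("glue")
    try:
        client.get_table(DatabaseName=database, Name=table)
        return True
    except client.exceptions.EntityNotFoundException:
        return False
```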

awswrangler/cloudwatch.py

Lines changed: 4 additions & 4 deletions

@@ -56,11 +56,11 @@ def start_query(
     ... )

     """
-    _logger.debug(f"log_group_names: {log_group_names}")
+    _logger.debug("log_group_names: %s", log_group_names)
     start_timestamp: int = int(1000 * start_time.timestamp())
     end_timestamp: int = int(1000 * end_time.timestamp())
-    _logger.debug(f"start_timestamp: {start_timestamp}")
-    _logger.debug(f"end_timestamp: {end_timestamp}")
+    _logger.debug("start_timestamp: %s", start_timestamp)
+    _logger.debug("end_timestamp: %s", end_timestamp)
     args: Dict[str, Any] = {
         "logGroupNames": log_group_names,
         "startTime": start_timestamp,
@@ -109,7 +109,7 @@ def wait_query(query_id: str, boto3_session: Optional[boto3.Session] = None) ->
         time.sleep(_QUERY_WAIT_POLLING_DELAY)
         response = client_logs.get_query_results(queryId=query_id)
         status = response["status"]
-        _logger.debug(f"status: {status}")
+        _logger.debug("status: %s", status)
         if status == "Failed":  # pragma: no cover
             raise exceptions.QueryFailed(f"query ID: {query_id}")
         if status == "Cancelled":

awswrangler/db.py

Lines changed: 38 additions & 29 deletions

@@ -155,29 +155,15 @@ def read_sql_query(
     ... )

     """
-    if not isinstance(con, sqlalchemy.engine.Engine):  # pragma: no cover
-        raise exceptions.InvalidConnection(
-            "Invalid 'con' argument, please pass a "
-            "SQLAlchemy Engine. Use wr.db.get_engine(), "
-            "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
-        )
+    _validate_engine(con=con)
     with con.connect() as _con:
         args = _convert_params(sql, params)
         cursor = _con.execute(*args)
         if chunksize is None:
             return _records2df(records=cursor.fetchall(), cols_names=cursor.keys(), index=index_col, dtype=dtype)
-        return _iterate_cursor(cursor=cursor, chunksize=chunksize, index=index_col, dtype=dtype)
-
-
-def _iterate_cursor(
-    cursor, chunksize: int, index: Optional[Union[str, List[str]]], dtype: Optional[Dict[str, pa.DataType]] = None
-) -> Iterator[pd.DataFrame]:
-    while True:
-        records = cursor.fetchmany(chunksize)
-        if not records:
-            break
-        df: pd.DataFrame = _records2df(records=records, cols_names=cursor.keys(), index=index, dtype=dtype)
-        yield df
+        return _iterate_cursor(
+            cursor=cursor, chunksize=chunksize, cols_names=cursor.keys(), index=index_col, dtype=dtype
+        )


 def _records2df(
@@ -207,6 +193,20 @@ def _records2df(
     return df


+def _iterate_cursor(
+    cursor: Any,
+    chunksize: int,
+    cols_names: List[str],
+    index: Optional[Union[str, List[str]]],
+    dtype: Optional[Dict[str, pa.DataType]] = None,
+) -> Iterator[pd.DataFrame]:
+    while True:
+        records = cursor.fetchmany(chunksize)
+        if not records:
+            break
+        yield _records2df(records=records, cols_names=cols_names, index=index, dtype=dtype)
+
+
 def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> List[Any]:
     args: List[Any] = [sql]
     if params is not None:
@@ -646,7 +646,7 @@ def copy_files_to_redshift(  # pylint: disable=too-many-locals,too-many-arguments
     athena_types, _ = s3.read_parquet_metadata(
         path=paths, dataset=False, use_threads=use_threads, boto3_session=session
     )
-    _logger.debug(f"athena_types: {athena_types}")
+    _logger.debug("athena_types: %s", athena_types)
     redshift_types: Dict[str, str] = {}
     for col_name, col_type in athena_types.items():
         length: int = _varchar_lengths[col_name] if col_name in _varchar_lengths else varchar_lengths_default
@@ -680,7 +680,7 @@
 def _rs_upsert(con: Any, table: str, temp_table: str, schema: str, primary_keys: Optional[List[str]] = None) -> None:
     if not primary_keys:
         primary_keys = _rs_get_primary_keys(con=con, schema=schema, table=table)
-    _logger.debug(f"primary_keys: {primary_keys}")
+    _logger.debug("primary_keys: %s", primary_keys)
     if not primary_keys:  # pragma: no cover
         raise exceptions.InvalidRedshiftPrimaryKeys()
     equals_clause: str = f"{table}.%s = {temp_table}.%s"
@@ -735,7 +735,7 @@ def _rs_create_table(
         f"{distkey_str}"
         f"{sortkey_str}"
     )
-    _logger.debug(f"Create table query:\n{sql}")
+    _logger.debug("Create table query:\n%s", sql)
     con.execute(sql)
     return table, schema

@@ -746,7 +746,7 @@ def _rs_validate_parameters(
     if diststyle not in _RS_DISTSTYLES:
         raise exceptions.InvalidRedshiftDiststyle(f"diststyle must be in {_RS_DISTSTYLES}")
     cols = list(redshift_types.keys())
-    _logger.debug(f"Redshift columns: {cols}")
+    _logger.debug("Redshift columns: %s", cols)
     if (diststyle == "KEY") and (not distkey):
         raise exceptions.InvalidRedshiftDistkey("You must pass a distkey if you intend to use KEY diststyle")
     if distkey and distkey not in cols:
@@ -775,13 +775,13 @@ def _rs_copy(
     sql: str = (
         f"COPY {table_name} FROM '{manifest_path}'\n" f"IAM_ROLE '{iam_role}'\n" "MANIFEST\n" "FORMAT AS PARQUET"
     )
-    _logger.debug(f"copy query:\n{sql}")
+    _logger.debug("copy query:\n%s", sql)
     con.execute(sql)
     sql = "SELECT pg_last_copy_id() AS query_id"
     query_id: int = con.execute(sql).fetchall()[0][0]
     sql = f"SELECT COUNT(DISTINCT filename) as num_files_loaded " f"FROM STL_LOAD_COMMITS WHERE query = {query_id}"
     num_files_loaded: int = con.execute(sql).fetchall()[0][0]
-    _logger.debug(f"{num_files_loaded} files counted. {num_files} expected.")
+    _logger.debug("%s files counted. %s expected.", num_files_loaded, num_files)
     if num_files_loaded != num_files:  # pragma: no cover
         raise exceptions.RedshiftLoadError(
             f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
@@ -846,17 +846,17 @@ def write_redshift_copy_manifest(
     payload: str = json.dumps(manifest)
     bucket: str
     bucket, key = _utils.parse_path(manifest_path)
-    _logger.debug(f"payload: {payload}")
+    _logger.debug("payload: %s", payload)
     client_s3: boto3.client = _utils.client(service_name="s3", session=session)
-    _logger.debug(f"bucket: {bucket}")
-    _logger.debug(f"key: {key}")
+    _logger.debug("bucket: %s", bucket)
+    _logger.debug("key: %s", key)
     client_s3.put_object(Body=payload, Bucket=bucket, Key=key)
     return manifest


 def _rs_drop_table(con: Any, schema: str, table: str) -> None:
     sql = f"DROP TABLE IF EXISTS {schema}.{table}"
-    _logger.debug(f"Drop table query:\n{sql}")
+    _logger.debug("Drop table query:\n%s", sql)
     con.execute(sql)

@@ -1104,5 +1104,14 @@ def unload_redshift_to_files(
     query_id: int = _con.execute(sql).fetchall()[0][0]
     sql = f"SELECT path FROM STL_UNLOAD_LOG WHERE query={query_id};"
     paths = [x[0].replace(" ", "") for x in _con.execute(sql).fetchall()]
-    _logger.debug(f"paths: {paths}")
+    _logger.debug("paths: %s", paths)
     return paths
+
+
+def _validate_engine(con: sqlalchemy.engine.Engine) -> None:  # pragma: no cover
+    if not isinstance(con, sqlalchemy.engine.Engine):
+        raise exceptions.InvalidConnection(
+            "Invalid 'con' argument, please pass a "
+            "SQLAlchemy Engine. Use wr.db.get_engine(), "
+            "wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()"
+        )
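
Two refactors stand out in this file: the engine check moves into a reusable `_validate_engine` helper, and `_iterate_cursor` now receives the column names up front instead of calling `cursor.keys()` on every batch inside the generator. The chunking itself is plain `fetchmany` batching; a self-contained sketch of the idea (the helper name and the bare `pd.DataFrame` construction are simplifications, not the library's `_records2df`):

```python
from typing import Any, Iterator, List

import pandas as pd


def iterate_chunks(cursor: Any, chunksize: int, cols_names: List[str]) -> Iterator[pd.DataFrame]:
    """Yield DataFrames of at most `chunksize` rows from a DB-API style cursor."""
    while True:
        records = cursor.fetchmany(chunksize)
        if not records:  # an empty batch means the result set is exhausted
            break
        yield pd.DataFrame(records, columns=cols_names)
```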
