
Commit 6eccdb4

Merge pull request #222 from awslabs/dev
Updating master branch
2 parents 87ca83f + 12d0f66 commit 6eccdb4

36 files changed: +3,236 / -321 lines

.github/workflows/static-checking.yml

Lines changed: 3 additions & 6 deletions
@@ -24,15 +24,12 @@ jobs:
       uses: actions/setup-python@v1
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install -r requirements.txt
-        pip install -r requirements-dev.txt
+    - name: Setup Environment
+      run: ./setup-dev-env.sh
     - name: CloudFormation Lint
       run: cfn-lint -t testing/cloudformation.yaml
     - name: Documentation Lint
-      run: pydocstyle awswrangler/ --add-ignore=D204
+      run: pydocstyle awswrangler/ --add-ignore=D204,D403
     - name: mypy check
       run: mypy awswrangler
     - name: Flake8 Lint

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -138,6 +138,8 @@ testing/*parameters-*.properties
 testing/*requirements*.txt
 testing/coverage/*
 building/*requirements*.txt
+building/arrow
+building/lambda/arrow
 /docs/coverage/
 /docs/build/
 /docs/source/_build/

.pylintrc

Lines changed: 2 additions & 1 deletion
@@ -141,7 +141,8 @@ disable=print-statement,
         comprehension-escape,
         C0330,
         C0103,
-        W1202
+        W1202,
+        too-few-public-methods

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

README.md

Lines changed: 6 additions & 4 deletions
@@ -5,19 +5,18 @@

 **NOTE**

-We just released a new major version `1.0` with breaking changes. Please make sure that all your old projects has dependencies frozen on the desired version (e.g. `pip install awswrangler==0.3.2`).
+Due the new major version `1.*.*` with breaking changes, please make sure that all your old projects has dependencies frozen on the desired version (e.g. `pip install awswrangler==0.3.2`).

 ---

 ![AWS Data Wrangler](docs/source/_static/logo2.png?raw=true "AWS Data Wrangler")

-[![Release](https://img.shields.io/badge/release-1.0.4-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Release](https://img.shields.io/badge/release-1.1.0-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-brightgreen.svg)](https://anaconda.org/conda-forge/awswrangler)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
-[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
-[![Average time to resolve an issue](http://isitmaintained.com/badge/resolution/awslabs/aws-data-wrangler.svg)](http://isitmaintained.com/project/awslabs/aws-data-wrangler "Average time to resolve an issue")

+[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
 [![Coverage](https://img.shields.io/badge/coverage-100%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 ![Static Checking](https://github.com/awslabs/aws-data-wrangler/workflows/Static%20Checking/badge.svg?branch=master)
 [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/?badge=latest)
@@ -85,6 +84,9 @@ df = wr.db.read_sql_query("SELECT * FROM external_schema.my_table", con=engine)
 - [11 - CSV Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/11%20-%20CSV%20Datasets.ipynb)
 - [12 - CSV Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/12%20-%20CSV%20Crawler.ipynb)
 - [13 - Merging Datasets on S3](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/13%20-%20Merging%20Datasets%20on%20S3.ipynb)
+- [14 - PyTorch](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/14%20-%20PyTorch.ipynb)
+- [15 - EMR](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/15%20-%20EMR.ipynb)
+- [16 - EMR & Docker](https://github.com/awslabs/aws-data-wrangler/blob/dev/tutorials/16%20-%20EMR%20%26%20Docker.ipynb)
 - [**API Reference**](https://aws-data-wrangler.readthedocs.io/en/latest/api.html)
 - [Amazon S3](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#amazon-s3)
 - [AWS Glue Catalog](https://aws-data-wrangler.readthedocs.io/en/latest/api.html#aws-glue-catalog)

awswrangler/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -6,8 +6,13 @@
 """

 import logging
+from importlib.util import find_spec

 from awswrangler import athena, catalog, cloudwatch, db, emr, exceptions, s3  # noqa
 from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa
+from awswrangler._utils import get_account_id  # noqa
+
+if find_spec("torch") and find_spec("torchvision") and find_spec("torchaudio") and find_spec("PIL"):
+    from awswrangler import torch  # noqa

 logging.getLogger("awswrangler").addHandler(logging.NullHandler())

awswrangler/__metadata__.py

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@

 __title__ = "awswrangler"
 __description__ = "Pandas on AWS."
-__version__ = "1.0.4"
+__version__ = "1.1.0"
 __license__ = "Apache License 2.0"

awswrangler/_data_types.py

Lines changed: 65 additions & 18 deletions
@@ -1,8 +1,9 @@
 """Internal (private) Data Types Module."""

 import logging
+import re
 from decimal import Decimal
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Match, Optional, Sequence, Tuple

 import pandas as pd  # type: ignore
 import pyarrow as pa  # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
         return f"decimal({dtype.precision},{dtype.scale})"
     if pa.types.is_list(dtype):
         return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
-    if pa.types.is_struct(dtype):  # pragma: no cover
-        return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_struct(dtype):
+        return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_map(dtype):  # pragma: no cover
+        return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
     if dtype == pa.null():
         raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension(  # pylint: disable=too-many-branches,too-many-retu

 def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
     dtype: pa.DataType, db_type: str
-) -> VisitableType:
+) -> Optional[VisitableType]:
     """Pyarrow to Athena data types conversion."""
     if pa.types.is_int8(dtype):
         return sqlalchemy.types.SmallInteger
@@ -207,14 +210,14 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-sta
         return sqlalchemy.types.Date
     if pa.types.is_binary(dtype):
         if db_type == "redshift":
-            raise exceptions.UnsupportedType(f"Binary columns are not supported for Redshift.")  # pragma: no cover
+            raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.")  # pragma: no cover
         return sqlalchemy.types.Binary
     if pa.types.is_decimal(dtype):
         return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
     if pa.types.is_dictionary(dtype):
         return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
     if dtype == pa.null():  # pragma: no cover
-        raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
+        return None
     raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
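
With the return type widened to Optional[VisitableType], an all-null column no longer raises; it maps to None, and callers are expected to drop or override such columns (the same `is not None` filter appears for the Athena schema later in this diff). A hedged caller-side sketch, with made-up column names and pyarrow2sqlalchemy assumed in scope:

import pyarrow as pa

pa_types = {"id": pa.int64(), "comment": pa.null()}

# Hypothetical caller: convert each PyArrow type, then discard columns whose
# SQLAlchemy type could not be determined (pyarrow2sqlalchemy returned None).
sql_types = {name: pyarrow2sqlalchemy(dtype=t, db_type="postgresql") for name, t in pa_types.items()}
sql_types = {name: t for name, t in sql_types.items() if t is not None}  # "comment" is dropped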
@@ -243,12 +246,23 @@ def pyarrow_types_from_pandas(
         else:
             cols.append(name)

-    # Filling cols_dtypes and indexes
+    # Filling cols_dtypes
+    for col in cols:
+        _logger.debug("Inferring PyArrow type from column: %s", col)
+        try:
+            schema: pa.Schema = pa.Schema.from_pandas(df=df[[col]], preserve_index=False)
+        except pa.ArrowInvalid as ex:  # pragma: no cover
+            cols_dtypes[col] = process_not_inferred_dtype(ex)
+        else:
+            cols_dtypes[col] = schema.field(col).type
+
+    # Filling indexes
     indexes: List[str] = []
-    for field in pa.Schema.from_pandas(df=df[cols], preserve_index=index):
-        name = str(field.name)
-        cols_dtypes[name] = field.type
-        if (name not in df.columns) and (index is True):
+    if index is True:
+        for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
+            name = str(field.name)
+            _logger.debug("Inferring PyArrow type from index: %s", name)
+            cols_dtypes[name] = field.type
             indexes.append(name)

     # Merging Index
@@ -257,10 +271,43 @@
     # Filling schema
     columns_types: Dict[str, pa.DataType]
     columns_types = {n: cols_dtypes[n] for n in sorted_cols}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     return columns_types


+def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
+    """Infer data type from PyArrow inference exception."""
+    ex_str = str(ex)
+    _logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
+    match: Optional[Match] = re.search(
+        pattern="Could not convert (.*) with type (.*): did not recognize "
+        "Python value type when inferring an Arrow data type",
+        string=ex_str,
+    )
+    if match is None:
+        raise ex  # pragma: no cover
+    groups: Optional[Sequence[str]] = match.groups()
+    if groups is None:
+        raise ex  # pragma: no cover
+    if len(groups) != 2:
+        raise ex  # pragma: no cover
+    _logger.debug("groups: %s", groups)
+    type_str: str = groups[1]
+    if type_str == "UUID":
+        return pa.string()
+    raise ex  # pragma: no cover
+
+
+def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
+    """Infer `pyarrow.array` from PyArrow inference exception."""
+    dtype = process_not_inferred_dtype(ex=ex)
+    if dtype == pa.string():
+        array: pa.Array = pa.array(obj=[str(x) for x in values], type=dtype, safe=True)
+    else:
+        raise ex  # pragma: no cover
+    return array
+
+
 def athena_types_from_pandas(
     df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
 ) -> Dict[str, str]:
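
These two helpers let a column that PyArrow cannot type on its own (per the regex, currently only values reported as type UUID) fall back to a string column instead of failing the whole operation. A small illustration of the failure they intercept, using stand-in data:

import uuid

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"request_id": [uuid.uuid4(), uuid.uuid4()]})

try:
    pa.Schema.from_pandas(df=df[["request_id"]], preserve_index=False)
except pa.ArrowInvalid as ex:
    # Message of the form "Could not convert ... with type UUID: did not recognize
    # Python value type when inferring an Arrow data type".
    # process_not_inferred_dtype() maps this to pa.string(), and
    # process_not_inferred_array() stores str(x) for each value.
    print(ex)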
@@ -275,7 +322,7 @@ def athena_types_from_pandas(
             athena_columns_types[k] = casts[k]
         else:
             athena_columns_types[k] = pyarrow2athena(dtype=v)
-    _logger.debug(f"athena_columns_types: {athena_columns_types}")
+    _logger.debug("athena_columns_types: %s", athena_columns_types)
     return athena_columns_types


@@ -315,7 +362,7 @@ def pyarrow_schema_from_pandas(
         if (k in df.columns) and (k not in ignore):
             columns_types[k] = athena2pyarrow(v)
     columns_types = {k: v for k, v in columns_types.items() if v is not None}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     return pa.schema(fields=columns_types)


@@ -324,11 +371,11 @@ def athena_types_from_pyarrow_schema(
 ) -> Tuple[Dict[str, str], Optional[Dict[str, str]]]:
     """Extract the related Athena data types from any PyArrow Schema considering possible partitions."""
     columns_types: Dict[str, str] = {str(f.name): pyarrow2athena(dtype=f.type) for f in schema}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
     partitions_types: Optional[Dict[str, str]] = None
     if partitions is not None:
         partitions_types = {p.name: pyarrow2athena(p.dictionary.type) for p in partitions}
-    _logger.debug(f"partitions_types: {partitions_types}")
+    _logger.debug("partitions_types: %s", partitions_types)
     return columns_types, partitions_types


@@ -372,7 +419,7 @@ def sqlalchemy_types_from_pandas(
     df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, VisitableType]] = None
 ) -> Dict[str, VisitableType]:
     """Extract the related SQLAlchemy data types from any Pandas DataFrame."""
-    casts: Dict[str, VisitableType] = dtype if dtype else {}
+    casts: Dict[str, VisitableType] = dtype if dtype is not None else {}
     pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
         df=df, index=False, ignore_cols=list(casts.keys())
     )
@@ -382,5 +429,5 @@
             sqlalchemy_columns_types[k] = casts[k]
         else:
             sqlalchemy_columns_types[k] = pyarrow2sqlalchemy(dtype=v, db_type=db_type)
-    _logger.debug(f"sqlalchemy_columns_types: {sqlalchemy_columns_types}")
+    _logger.debug("sqlalchemy_columns_types: %s", sqlalchemy_columns_types)
     return sqlalchemy_columns_types
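
The recurring change in this file replaces f-string log messages with %-style arguments, so interpolation is deferred until a record is actually emitted. A generic illustration, not code from the commit:

import logging

_logger = logging.getLogger("awswrangler")

columns_types = {"col0": "bigint", "col1": "string"}  # placeholder payload

# Eager: the f-string is always built, even when DEBUG is disabled.
_logger.debug(f"columns_types: {columns_types}")

# Lazy: formatting happens only if a handler at DEBUG level emits the record.
_logger.debug("columns_types: %s", columns_types)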

awswrangler/_utils.py

Lines changed: 13 additions & 0 deletions
@@ -166,3 +166,16 @@ def ensure_postgresql_casts():
 def get_directory(path: str) -> str:
     """Extract directory path."""
     return path.rsplit(sep="/", maxsplit=1)[0] + "/"
+
+
+def get_account_id(boto3_session: Optional[boto3.Session] = None) -> str:
+    """Get Account ID."""
+    session: boto3.Session = ensure_session(session=boto3_session)
+    return client(service_name="sts", session=session).get_caller_identity().get("Account")
+
+
+def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session] = None) -> str:
+    """Extract region from Subnet ID."""
+    session: boto3.Session = ensure_session(session=boto3_session)
+    client_ec2: boto3.client = client(service_name="ec2", session=session)
+    return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9]
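
Both helpers are thin wrappers over boto3. A rough equivalent using boto3 directly (default credentials assumed; the subnet ID below is a placeholder):

import boto3

session = boto3.Session()

# get_account_id: STS reports the account behind the current credentials.
account_id = session.client("sts").get_caller_identity()["Account"]

# get_region_from_subnet: read the subnet's Availability Zone (e.g. "us-east-1a")
# and keep the leading characters as the region, mirroring the [:9] slice above.
subnet = session.client("ec2").describe_subnets(SubnetIds=["subnet-0123456789abcdef0"])["Subnets"][0]
region = subnet["AvailabilityZone"][:9]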
