Skip to content

Commit 5b8657a

Browse files
committed
Improving data type casting.
1 parent c55f18c commit 5b8657a

File tree

2 files changed

+76
-44
lines changed

2 files changed

+76
-44
lines changed

awswrangler/_data_types.py

Lines changed: 64 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import sqlalchemy_redshift.dialect # type: ignore
1515
from sqlalchemy.sql.visitors import VisitableType # type: ignore
1616

17-
from awswrangler import exceptions
17+
from awswrangler import _utils, exceptions
1818

1919
_logger: logging.Logger = logging.getLogger(__name__)
2020

@@ -44,14 +44,14 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
4444
return pa.date32()
4545
if dtype in ("binary" or "varbinary"):
4646
return pa.binary()
47-
if dtype.startswith("decimal"):
47+
if dtype.startswith("decimal") is True:
4848
precision, scale = dtype.replace("decimal(", "").replace(")", "").split(sep=",")
4949
return pa.decimal128(precision=int(precision), scale=int(scale))
50-
if dtype.startswith("array"):
51-
return pa.large_list(athena2pyarrow(dtype=dtype[6:-1]))
52-
if dtype.startswith("struct"):
50+
if dtype.startswith("array") is True:
51+
return pa.list_(value_type=athena2pyarrow(dtype=dtype[6:-1]), list_size=-1)
52+
if dtype.startswith("struct") is True:
5353
return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")])
54-
if dtype.startswith("map"): # pragma: no cover
54+
if dtype.startswith("map") is True: # pragma: no cover
5555
return pa.map_(athena2pyarrow(dtype[4:-1].split(",", 1)[0]), athena2pyarrow(dtype[4:-1].split(",", 1)[1]))
5656
raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}") # pragma: no cover
5757

@@ -396,7 +396,7 @@ def pyarrow_schema_from_pandas(
396396
)
397397
for k, v in casts.items():
398398
if (k in df.columns) and (k not in ignore):
399-
columns_types[k] = athena2pyarrow(v)
399+
columns_types[k] = athena2pyarrow(dtype=v)
400400
columns_types = {k: v for k, v in columns_types.items() if v is not None}
401401
_logger.debug("columns_types: %s", columns_types)
402402
return pa.schema(fields=columns_types)
@@ -417,47 +417,67 @@ def athena_types_from_pyarrow_schema(
417417

418418
def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
    """Cast columns in a Pandas DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose columns may need to be converted.
    dtype : Dict[str, str]
        Mapping of column name to Athena type name. Entries whose column is
        not present in ``df`` and entries whose type starts with ``array``,
        ``struct`` or ``map`` are skipped.

    Returns
    -------
    pd.DataFrame
        The frame with the requested columns cast. This is either ``df``
        itself or whatever ``_utils.ensure_df_is_mutable`` returned for it.
    """
    mutable_ensured: bool = False
    for col, athena_type in dtype.items():
        if col not in df.columns:
            continue
        if athena_type.startswith(("array", "struct", "map")):
            continue  # nested types are never cast here
        desired_type: str = athena2pandas(dtype=athena_type)
        current_type: str = _normalize_pandas_dtype_name(dtype=str(df[col].dtypes))
        if desired_type == current_type:
            continue  # already the requested type, nothing to convert
        _logger.debug("current_type: %s -> desired_type: %s", current_type, desired_type)
        if mutable_ensured is False:
            # Only pay for the mutability check once, and only when at least
            # one column really has to be rewritten.
            df = _utils.ensure_df_is_mutable(df=df)
            mutable_ensured = True
        _cast_pandas_column(df=df, col=col, current_type=current_type, desired_type=desired_type)

    return df
438+
439+
440+
def _normalize_pandas_dtype_name(dtype: str) -> str:
    """Reduce a pandas dtype name to its short family name.

    Any name starting with ``datetime64`` (e.g. ``datetime64[ns]``) becomes
    ``"datetime64"`` and any name starting with ``decimal`` becomes
    ``"decimal"``. Every other name is returned unchanged.
    """
    if dtype.startswith("datetime64") is True:
        return "datetime64"
    if dtype.startswith("decimal") is True:
        return "decimal"  # pragma: no cover
    return dtype
446+
447+
448+
def _cast_pandas_column(df: pd.DataFrame, col: str, current_type: str, desired_type: str) -> pd.DataFrame:
    """Cast one column of ``df`` in place to ``desired_type``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding the column; ``df[col]`` is overwritten.
    col : str
        Name of the column to cast.
    current_type : str
        Name of the column's current dtype. Only consulted when
        ``desired_type`` is ``"string"``.
    desired_type : str
        Target type. ``"datetime64"``, ``"date"``, ``"bytes"``, ``"decimal"``
        and ``"string"`` have dedicated conversions; any other value is
        handed to ``Series.astype``.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object that was passed in.
    """
    # String renderings that are treated as a missing value.
    null_tokens = ("", "none", "None", " ", "<NA>")

    if desired_type == "datetime64":
        df[col] = pd.to_datetime(df[col])
    elif desired_type == "date":
        df[col] = pd.to_datetime(df[col]).dt.date.replace(to_replace={pd.NaT: None})
    elif desired_type == "bytes":
        df[col] = df[col].astype("string").str.encode(encoding="utf-8").replace(to_replace={pd.NA: None})
    elif desired_type == "decimal":
        df[col] = (
            df[col]
            .astype("string")
            .apply(lambda x: Decimal(str(x)) if str(x) not in null_tokens else None)
        )
    elif desired_type == "string":
        # Integer, float, object and category columns go through the builtin
        # ``str`` first; everything else is cast to "string" directly.
        via_builtin_str: bool = (
            current_type.lower().startswith("int")
            or current_type.startswith("float")
            or current_type in ("object", "category")
        )
        if via_builtin_str is True:
            df[col] = df[col].astype(str).astype("string")
        else:
            df[col] = df[col].astype("string")
    else:
        try:
            df[col] = df[col].astype(desired_type)
        except TypeError as ex:
            if "object cannot be converted to an IntegerDtype" not in str(ex):
                raise  # pragma: no cover
            # Fallback: convert element by element, mapping the null
            # placeholders to None, then retry the cast.
            df[col] = (
                df[col]
                .apply(lambda x: int(x) if str(x) not in null_tokens else None)
                .astype(desired_type)
            )
    return df
462482

463483

awswrangler/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import boto3 # type: ignore
1111
import botocore.config # type: ignore
1212
import numpy as np # type: ignore
13+
import pandas as pd # type: ignore
1314
import psycopg2 # type: ignore
1415
import s3fs # type: ignore
1516

@@ -293,3 +294,14 @@ def list_sampling(lst: List[Any], sampling: float) -> List[Any]:
293294
_logger.debug("sampling: %s", sampling)
294295
_logger.debug("num_samples: %s", num_samples)
295296
return random.sample(population=lst, k=num_samples)
297+
298+
299+
def ensure_df_is_mutable(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure that all columns have the writeable flag True.

    Each column's underlying ``values`` object is inspected. As soon as one
    exposes ``flags`` with ``writeable`` set to False, a deep copy of the
    whole frame is returned. Otherwise the input frame is returned as is.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to inspect.

    Returns
    -------
    pd.DataFrame
        ``df`` itself, or ``df.copy(deep=True)`` when a read-only column was
        found.
    """
    # Iterate positionally through ``items()``: with duplicated column labels
    # ``df[label]`` yields a DataFrame rather than the individual column.
    for _, series in df.items():
        values = series.values
        # Containers without a ``flags`` attribute are skipped.
        if hasattr(values, "flags") is True:
            if values.flags.writeable is False:  # pragma: no cover
                return df.copy(deep=True)
    return df

0 commit comments

Comments
 (0)