Skip to content

Commit c957b16

Browse files
committed
Improve import speed #460
1 parent 19adde1 commit c957b16

File tree

2 files changed: +13 additions, -10 deletions

awswrangler/_data_types.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
from decimal import Decimal
7-
from typing import Any, Dict, List, Match, Optional, Sequence, Tuple
7+
from typing import TYPE_CHECKING, Any, Dict, List, Match, Optional, Sequence, Tuple
88

99
import numpy as np
1010
import pandas as pd
@@ -13,11 +13,12 @@
1313
import sqlalchemy
1414
import sqlalchemy.dialects.mysql
1515
import sqlalchemy.dialects.postgresql
16-
import sqlalchemy_redshift.dialect
17-
from sqlalchemy.sql.visitors import VisitableType
1816

1917
from awswrangler import _utils, exceptions
2018

19+
if TYPE_CHECKING:
20+
from sqlalchemy.sql.visitors import VisitableType
21+
2122
_logger: logging.Logger = logging.getLogger(__name__)
2223

2324

@@ -210,7 +211,7 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu
210211

211212
def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-statements
212213
dtype: pa.DataType, db_type: str
213-
) -> Optional[VisitableType]:
214+
) -> Optional["VisitableType"]:
214215
"""Pyarrow to Athena data types conversion."""
215216
if pa.types.is_int8(dtype):
216217
return sqlalchemy.types.SmallInteger
@@ -228,6 +229,8 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
228229
if db_type == "postgresql":
229230
return sqlalchemy.dialects.postgresql.DOUBLE_PRECISION
230231
if db_type == "redshift":
232+
import sqlalchemy_redshift.dialect # pylint: disable=import-outside-toplevel
233+
231234
return sqlalchemy_redshift.dialect.DOUBLE_PRECISION
232235
raise exceptions.InvalidDatabaseType(
233236
f"{db_type} is a invalid database type, please choose between postgresql, mysql and redshift."
@@ -509,14 +512,14 @@ def _cast_pandas_column(df: pd.DataFrame, col: str, current_type: str, desired_t
509512

510513

511514
def sqlalchemy_types_from_pandas(
512-
df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, VisitableType]] = None
513-
) -> Dict[str, VisitableType]:
515+
df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, "VisitableType"]] = None
516+
) -> Dict[str, "VisitableType"]:
514517
"""Extract the related SQLAlchemy data types from any Pandas DataFrame."""
515-
casts: Dict[str, VisitableType] = dtype if dtype is not None else {}
518+
casts: Dict[str, "VisitableType"] = dtype if dtype is not None else {}
516519
pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
517520
df=df, index=False, ignore_cols=list(casts.keys())
518521
)
519-
sqlalchemy_columns_types: Dict[str, VisitableType] = {}
522+
sqlalchemy_columns_types: Dict[str, "VisitableType"] = {}
520523
for k, v in pa_columns_types.items():
521524
if v is None:
522525
sqlalchemy_columns_types[k] = casts[k]

awswrangler/catalog/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,11 +48,11 @@
4848
)
4949

5050
__all__ = [
51-
"add_column"
51+
"add_column",
5252
"add_csv_partitions",
5353
"add_parquet_partitions",
5454
"does_table_exist",
55-
"delete_column"
55+
"delete_column",
5656
"drop_duplicated_columns",
5757
"extract_athena_types",
5858
"sanitize_column_name",

0 commit comments

Comments
 (0)