66import logging
77import uuid
88from ssl import SSLContext
9- from typing import Any , Iterator , Literal , cast , overload
9+ from typing import TYPE_CHECKING , Any , Iterator , Literal , cast , overload
1010
1111import boto3
1212import pyarrow as pa
1313
1414import awswrangler .pandas as pd
15- from awswrangler import _data_types , _utils , exceptions
15+ from awswrangler import _data_types , _sql_utils , _utils , exceptions
1616from awswrangler import _databases as _db_utils
1717from awswrangler ._config import apply_configs
1818
19- pg8000 = _utils .import_optional_dependency ("pg8000" )
20- pg8000_native = _utils .import_optional_dependency ("pg8000.native" )
19+ if TYPE_CHECKING :
20+ try :
21+ import pg8000
22+ from pg8000 import native as pg8000_native
23+ except ImportError :
24+ pass
25+ else :
26+ pg8000 = _utils .import_optional_dependency ("pg8000" )
27+ pg8000_native = _utils .import_optional_dependency ("pg8000.native" )
2128
2229_logger : logging .Logger = logging .getLogger (__name__ )
2330
2431
32+ def _identifier (sql : str ) -> str :
33+ return _sql_utils .identifier (sql , sql_mode = "ansi" )
34+
35+
2536def _validate_connection (con : "pg8000.Connection" ) -> None :
2637 if not isinstance (con , pg8000 .Connection ):
2738 raise exceptions .InvalidConnection (
@@ -32,8 +43,8 @@ def _validate_connection(con: "pg8000.Connection") -> None:
3243
3344
3445def _drop_table (cursor : "pg8000.Cursor" , schema : str | None , table : str ) -> None :
35- schema_str = f"{ pg8000_native . identifier (schema )} ." if schema else ""
36- sql = f"DROP TABLE IF EXISTS { schema_str } { pg8000_native . identifier (table )} "
46+ schema_str = f"{ _identifier (schema )} ." if schema else ""
47+ sql = f"DROP TABLE IF EXISTS { schema_str } { _identifier (table )} "
3748 _logger .debug ("Drop table query:\n %s" , sql )
3849 cursor .execute (sql )
3950
@@ -71,15 +82,15 @@ def _create_table(
7182 varchar_lengths = varchar_lengths ,
7283 converter_func = _data_types .pyarrow2postgresql ,
7384 )
74- cols_str : str = "" .join ([f"{ pg8000_native . identifier (k )} { v } ,\n " for k , v in postgresql_types .items ()])[:- 2 ]
75- sql = f"CREATE TABLE IF NOT EXISTS { pg8000_native . identifier (schema )} .{ pg8000_native . identifier (table )} (\n { cols_str } )"
85+ cols_str : str = "" .join ([f"{ _identifier (k )} { v } ,\n " for k , v in postgresql_types .items ()])[:- 2 ]
86+ sql = f"CREATE TABLE IF NOT EXISTS { _identifier (schema )} .{ _identifier (table )} (\n { cols_str } )"
7687 _logger .debug ("Create table query:\n %s" , sql )
7788 cursor .execute (sql )
7889
7990
8091def _iterate_server_side_cursor (
8192 sql : str ,
82- con : Any ,
93+ con : "pg8000.Connection" ,
8394 chunksize : int ,
8495 index_col : str | list [str ] | None ,
8596 params : list [Any ] | tuple [Any , ...] | dict [Any , Any ] | None ,
@@ -97,16 +108,12 @@ def _iterate_server_side_cursor(
97108 """
98109 with con .cursor () as cursor :
99110 sscursor_name : str = f"c_{ uuid .uuid4 ().hex } "
100- cursor_args = _db_utils ._convert_params (
101- f"DECLARE { pg8000_native .identifier (sscursor_name )} CURSOR FOR { sql } " , params
102- )
111+ cursor_args = _db_utils ._convert_params (f"DECLARE { _identifier (sscursor_name )} CURSOR FOR { sql } " , params )
103112 cursor .execute (* cursor_args )
104113
105114 try :
106115 while True :
107- cursor .execute (
108- f"FETCH FORWARD { pg8000_native .literal (chunksize )} FROM { pg8000_native .identifier (sscursor_name )} "
109- )
116+ cursor .execute (f"FETCH FORWARD { pg8000_native .literal (chunksize )} FROM { _identifier (sscursor_name )} " )
110117 records = cursor .fetchall ()
111118
112119 if not records :
@@ -122,7 +129,7 @@ def _iterate_server_side_cursor(
122129 dtype_backend = dtype_backend ,
123130 )
124131 finally :
125- cursor .execute (f"CLOSE { pg8000_native . identifier (sscursor_name )} " )
132+ cursor .execute (f"CLOSE { _identifier (sscursor_name )} " )
126133
127134
128135@_utils .check_optional_dependency (pg8000 , "pg8000" )
@@ -466,9 +473,9 @@ def read_sql_table(
466473
467474 """
468475 sql : str = (
469- f"SELECT * FROM { pg8000_native . identifier (table )} "
476+ f"SELECT * FROM { _identifier (table )} "
470477 if schema is None
471- else f"SELECT * FROM { pg8000_native . identifier (schema )} .{ pg8000_native . identifier (table )} "
478+ else f"SELECT * FROM { _identifier (schema )} .{ _identifier (table )} "
472479 )
473480 return read_sql_query (
474481 sql = sql ,
@@ -586,7 +593,7 @@ def to_sql(
586593 if index :
587594 df .reset_index (level = df .index .names , inplace = True )
588595 column_placeholders : str = ", " .join (["%s" ] * len (df .columns ))
589- column_names = [pg8000_native . identifier (column ) for column in df .columns ]
596+ column_names = [_identifier (column ) for column in df .columns ]
590597 insertion_columns = ""
591598 upsert_str = ""
592599 if use_column_names :
@@ -602,7 +609,7 @@ def to_sql(
602609 df = df , column_placeholders = column_placeholders , chunksize = chunksize
603610 )
604611 for placeholders , parameters in placeholder_parameter_pair_generator :
605- sql : str = f"INSERT INTO { pg8000_native . identifier (schema )} .{ pg8000_native . identifier (table )} { insertion_columns } VALUES { placeholders } { upsert_str } "
612+ sql : str = f"INSERT INTO { _identifier (schema )} .{ _identifier (table )} { insertion_columns } VALUES { placeholders } { upsert_str } "
606613 _logger .debug ("sql: %s" , sql )
607614 cursor .executemany (sql , (parameters ,))
608615 con .commit ()
0 commit comments