Skip to content

Commit da01429

Browse files
committed
Improving database internals.
1 parent 7562397 commit da01429

File tree

8 files changed

+119
-38
lines changed

8 files changed

+119
-38
lines changed

awswrangler/_databases.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,23 +116,28 @@ def read_sql_query(
116116
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
117117
"""Read SQL Query (generic)."""
118118
args = _convert_params(sql, params)
119-
with con.cursor() as cursor:
120-
cursor.execute(*args)
121-
cols_names: List[str] = [
122-
col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description
123-
]
124-
_logger.debug("cols_names: %s", cols_names)
125-
if chunksize is None:
126-
return _records2df(
127-
records=cast(List[Tuple[Any]], cursor.fetchall()),
128-
cols_names=cols_names,
129-
index=index_col,
130-
dtype=dtype,
131-
safe=safe,
119+
try:
120+
with con.cursor() as cursor:
121+
cursor.execute(*args)
122+
cols_names: List[str] = [
123+
col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description
124+
]
125+
_logger.debug("cols_names: %s", cols_names)
126+
if chunksize is None:
127+
return _records2df(
128+
records=cast(List[Tuple[Any]], cursor.fetchall()),
129+
cols_names=cols_names,
130+
index=index_col,
131+
dtype=dtype,
132+
safe=safe,
133+
)
134+
return _iterate_cursor(
135+
cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe
132136
)
133-
return _iterate_cursor(
134-
cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe
135-
)
137+
except Exception as ex:
138+
con.rollback()
139+
_logger.error(ex)
140+
raise
136141

137142

138143
def extract_parameters(df: pd.DataFrame) -> List[List[Any]]:

awswrangler/mysql.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def _validate_connection(con: pymysql.connections.Connection) -> None:
2626

2727

2828
def _drop_table(cursor: Cursor, schema: Optional[str], table: str) -> None:
29-
schema_str = f"{schema}." if schema else ""
30-
sql = f"DROP TABLE IF EXISTS {schema_str}{table}"
29+
schema_str = f"`{schema}`." if schema else ""
30+
sql = f"DROP TABLE IF EXISTS {schema_str}`{table}`"
3131
_logger.debug("Drop table query:\n%s", sql)
3232
cursor.execute(sql)
3333

@@ -61,7 +61,7 @@ def _create_table(
6161
converter_func=_data_types.pyarrow2mysql,
6262
)
6363
cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
64-
sql = f"CREATE TABLE IF NOT EXISTS {schema}.{table} (\n" f"{cols_str}" f")"
64+
sql = f"CREATE TABLE IF NOT EXISTS `{schema}`.`{table}` (\n{cols_str})"
6565
_logger.debug("Create table query:\n%s", sql)
6666
cursor.execute(sql)
6767

@@ -246,7 +246,7 @@ def read_sql_table(
246246
>>> con.close()
247247
248248
"""
249-
sql: str = f"SELECT * FROM {table}" if schema is None else f"SELECT * FROM {schema}.{table}"
249+
sql: str = f"SELECT * FROM `{table}`" if schema is None else f"SELECT * FROM `{schema}`.`{table}`"
250250
return read_sql_query(
251251
sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
252252
)
@@ -310,7 +310,6 @@ def to_sql(
310310
if df.empty is True:
311311
raise exceptions.EmptyDataFrame()
312312
_validate_connection(con=con)
313-
con.autocommit(True) # type: ignore
314313
try:
315314
with con.cursor() as cursor:
316315
_create_table(
@@ -326,7 +325,7 @@ def to_sql(
326325
if index:
327326
df.reset_index(level=df.index.names, inplace=True)
328327
placeholders: str = ", ".join(["%s"] * len(df.columns))
329-
sql: str = f"INSERT INTO {schema}.{table} VALUES ({placeholders})"
328+
sql: str = f"INSERT INTO `{schema}`.`{table}` VALUES ({placeholders})"
330329
_logger.debug("sql: %s", sql)
331330
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
332331
cursor.executemany(sql, parameters)

awswrangler/postgresql.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def _validate_connection(con: pg8000.Connection) -> None:
2525

2626

2727
def _drop_table(cursor: pg8000.Cursor, schema: Optional[str], table: str) -> None:
28-
schema_str = f"{schema}." if schema else ""
29-
sql = f"DROP TABLE IF EXISTS {schema_str}{table}"
28+
schema_str = f'"{schema}".' if schema else ""
29+
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"'
3030
_logger.debug("Drop table query:\n%s", sql)
3131
cursor.execute(sql)
3232

@@ -65,7 +65,7 @@ def _create_table(
6565
converter_func=_data_types.pyarrow2postgresql,
6666
)
6767
cols_str: str = "".join([f"{k} {v},\n" for k, v in postgresql_types.items()])[:-2]
68-
sql = f"CREATE TABLE IF NOT EXISTS {schema}.{table} (\n" f"{cols_str}" f")"
68+
sql = f'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" (\n{cols_str})'
6969
_logger.debug("Create table query:\n%s", sql)
7070
cursor.execute(sql)
7171

@@ -250,7 +250,7 @@ def read_sql_table(
250250
>>> con.close()
251251
252252
"""
253-
sql: str = f"SELECT * FROM {table}" if schema is None else f"SELECT * FROM {schema}.{table}"
253+
sql: str = f'SELECT * FROM "{table}"' if schema is None else f'SELECT * FROM "{schema}"."{table}"'
254254
return read_sql_query(
255255
sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
256256
)
@@ -314,7 +314,6 @@ def to_sql(
314314
if df.empty is True:
315315
raise exceptions.EmptyDataFrame()
316316
_validate_connection(con=con)
317-
con.autocommit = True
318317
try:
319318
with con.cursor() as cursor:
320319
_create_table(
@@ -330,7 +329,7 @@ def to_sql(
330329
if index:
331330
df.reset_index(level=df.index.names, inplace=True)
332331
placeholders: str = ", ".join(["%s"] * len(df.columns))
333-
sql: str = f"INSERT INTO {schema}.{table} VALUES ({placeholders})"
332+
sql: str = f'INSERT INTO "{schema}"."{table}" VALUES ({placeholders})'
334333
_logger.debug("sql: %s", sql)
335334
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
336335
cursor.executemany(sql, parameters)

awswrangler/redshift.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def _validate_connection(con: redshift_connector.Connection) -> None:
2929

3030

3131
def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
32-
schema_str = f"{schema}." if schema else ""
33-
sql = f"DROP TABLE IF EXISTS {schema_str}{table}"
32+
schema_str = f'"{schema}".' if schema else ""
33+
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"'
3434
_logger.debug("Drop table query:\n%s", sql)
3535
cursor.execute(sql)
3636

@@ -62,9 +62,9 @@ def _copy(
6262
schema: Optional[str] = None,
6363
) -> None:
6464
if schema is None:
65-
table_name: str = table
65+
table_name: str = f'"{table}"'
6666
else:
67-
table_name = f"{schema}.{table}"
67+
table_name = f'"{schema}"."{table}"'
6868
sql: str = f"COPY {table_name} FROM '{path}'\nIAM_ROLE '{iam_role}'\nFORMAT AS PARQUET"
6969
_logger.debug("copy query:\n%s", sql)
7070
cursor.execute(sql)
@@ -84,7 +84,7 @@ def _upsert(
8484
raise exceptions.InvalidRedshiftPrimaryKeys()
8585
equals_clause: str = f"{table}.%s = {temp_table}.%s"
8686
join_clause: str = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
87-
sql: str = f"DELETE FROM {schema}.{table} USING {temp_table} WHERE {join_clause}"
87+
sql: str = f'DELETE FROM "{schema}"."{table}" USING {temp_table} WHERE {join_clause}'
8888
_logger.debug(sql)
8989
cursor.execute(sql)
9090
sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
@@ -177,7 +177,7 @@ def _create_table(
177177
if mode == "upsert":
178178
guid: str = uuid.uuid4().hex
179179
temp_table: str = f"temp_redshift_{guid}"
180-
sql: str = f"CREATE TEMPORARY TABLE {temp_table} (LIKE {schema}.{table})"
180+
sql: str = f'CREATE TEMPORARY TABLE {temp_table} (LIKE "{schema}"."{table}")'
181181
_logger.debug(sql)
182182
cursor.execute(sql)
183183
return temp_table, None
@@ -217,7 +217,7 @@ def _create_table(
217217
distkey_str: str = f"\nDISTKEY({distkey})" if distkey and diststyle == "KEY" else ""
218218
sortkey_str: str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})" if sortkey else ""
219219
sql = (
220-
f"CREATE TABLE IF NOT EXISTS {schema}.{table} (\n"
220+
f'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" (\n'
221221
f"{cols_str}"
222222
f"{primary_keys_str}"
223223
f")\nDISTSTYLE {diststyle}"
@@ -538,7 +538,7 @@ def read_sql_table(
538538
>>> con.close()
539539
540540
"""
541-
sql: str = f"SELECT * FROM {table}" if schema is None else f"SELECT * FROM {schema}.{table}"
541+
sql: str = f'SELECT * FROM "{table}"' if schema is None else f'SELECT * FROM "{schema}"."{table}"'
542542
return read_sql_query(
543543
sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
544544
)
@@ -627,6 +627,7 @@ def to_sql(
627627
if df.empty is True:
628628
raise exceptions.EmptyDataFrame()
629629
_validate_connection(con=con)
630+
autocommit_temp: bool = con.autocommit
630631
con.autocommit = False
631632
try:
632633
with con.cursor() as cursor:
@@ -650,8 +651,8 @@ def to_sql(
650651
if index:
651652
df.reset_index(level=df.index.names, inplace=True)
652653
placeholders: str = ", ".join(["%s"] * len(df.columns))
653-
schema_str = f"{created_schema}." if created_schema else ""
654-
sql: str = f"INSERT INTO {schema_str}{created_table} VALUES ({placeholders})"
654+
schema_str = f'"{created_schema}".' if created_schema else ""
655+
sql: str = f'INSERT INTO {schema_str}"{created_table}" VALUES ({placeholders})'
655656
_logger.debug("sql: %s", sql)
656657
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
657658
cursor.executemany(sql, parameters)
@@ -662,6 +663,8 @@ def to_sql(
662663
con.rollback()
663664
_logger.error(ex)
664665
raise
666+
finally:
667+
con.autocommit = autocommit_temp
665668

666669

667670
def unload_to_files(
@@ -1009,6 +1012,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
10091012
>>> con.close()
10101013
10111014
"""
1015+
autocommit_temp: bool = con.autocommit
10121016
con.autocommit = False
10131017
try:
10141018
with con.cursor() as cursor:
@@ -1047,6 +1051,8 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
10471051
con.rollback()
10481052
_logger.error(ex)
10491053
raise
1054+
finally:
1055+
con.autocommit = autocommit_temp
10501056

10511057

10521058
def copy( # pylint: disable=too-many-arguments

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def redshift_table():
208208
con = wr.redshift.connect("aws-data-wrangler-redshift")
209209
with con.cursor() as cursor:
210210
cursor.execute(f"DROP TABLE IF EXISTS public.{name}")
211+
con.commit()
211212
con.close()
212213

213214

@@ -219,6 +220,7 @@ def postgresql_table():
219220
con = wr.postgresql.connect("aws-data-wrangler-postgresql")
220221
with con.cursor() as cursor:
221222
cursor.execute(f"DROP TABLE IF EXISTS public.{name}")
223+
con.commit()
222224
con.close()
223225

224226

@@ -230,6 +232,7 @@ def mysql_table():
230232
con = wr.mysql.connect("aws-data-wrangler-mysql")
231233
with con.cursor() as cursor:
232234
cursor.execute(f"DROP TABLE IF EXISTS test.{name}")
235+
con.commit()
233236
con.close()
234237

235238

tests/test_mysql.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,26 @@ def test_decimal_cast(mysql_table):
148148
assert 3.88 <= df2.col1.sum() <= 3.89
149149
assert df2.col2.sum() == 2
150150
con.close()
151+
152+
153+
def test_read_retry():
154+
con = wr.mysql.connect(connection="aws-data-wrangler-mysql")
155+
try:
156+
wr.mysql.read_sql_query("ERROR", con)
157+
except: # noqa
158+
pass
159+
df = wr.mysql.read_sql_query("SELECT 1", con)
160+
assert df.shape == (1, 1)
161+
con.close()
162+
163+
164+
def test_table_name():
165+
df = pd.DataFrame({"col0": [1]})
166+
con = wr.mysql.connect(connection="aws-data-wrangler-mysql")
167+
wr.mysql.to_sql(df, con, "Test Name", "test")
168+
df = wr.mysql.read_sql_table(schema="test", table="Test Name")
169+
assert df.shape == (1, 1)
170+
with con.cursor() as cursor:
171+
cursor.execute("DROP TABLE `Test Name`")
172+
con.commit()
173+
con.close()

tests/test_postgresql.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,26 @@ def test_decimal_cast(postgresql_table):
148148
assert 3.88 <= df2.col1.sum() <= 3.89
149149
assert df2.col2.sum() == 2
150150
con.close()
151+
152+
153+
def test_read_retry():
154+
con = wr.postgresql.connect(connection="aws-data-wrangler-postgresql")
155+
try:
156+
wr.postgresql.read_sql_query("ERROR", con)
157+
except: # noqa
158+
pass
159+
df = wr.postgresql.read_sql_query("SELECT 1", con)
160+
assert df.shape == (1, 1)
161+
con.close()
162+
163+
164+
def test_table_name():
165+
df = pd.DataFrame({"col0": [1]})
166+
con = wr.postgresql.connect(connection="aws-data-wrangler-postgresql")
167+
wr.postgresql.to_sql(df, con, "Test Name", "public", mode="overwrite")
168+
df = wr.postgresql.read_sql_table(schema="public", con=con, table="Test Name")
169+
assert df.shape == (1, 1)
170+
with con.cursor() as cursor:
171+
cursor.execute('DROP TABLE "Test Name"')
172+
con.commit()
173+
con.close()

tests/test_redshift.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,26 @@ def test_upsert(redshift_table):
667667
assert len(df.columns) == len(df4.columns)
668668

669669
con.close()
670+
671+
672+
def test_read_retry():
673+
con = wr.redshift.connect(connection="aws-data-wrangler-redshift")
674+
try:
675+
wr.redshift.read_sql_query("ERROR", con)
676+
except: # noqa
677+
pass
678+
df = wr.redshift.read_sql_query("SELECT 1", con)
679+
assert df.shape == (1, 1)
680+
con.close()
681+
682+
683+
def test_table_name():
684+
df = pd.DataFrame({"col0": [1]})
685+
con = wr.redshift.connect(connection="aws-data-wrangler-redshift")
686+
wr.redshift.to_sql(df, con, "Test Name", "public", mode="overwrite")
687+
df = wr.redshift.read_sql_table(schema="public", con=con, table="Test Name")
688+
assert df.shape == (1, 1)
689+
with con.cursor() as cursor:
690+
cursor.execute('DROP TABLE "Test Name"')
691+
con.commit()
692+
con.close()

0 commit comments

Comments
 (0)