Skip to content

Commit 13c7e70

Browse files
committed
Improve mysql and postgres tests tear down.
1 parent 73a0987 commit 13c7e70

File tree

3 files changed

+72
-36
lines changed

3 files changed

+72
-36
lines changed

awswrangler/db.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,29 @@ def _convert_params(sql: str, params: Optional[Union[List, Tuple, Dict]]) -> Lis
227227
return args
228228

229229

230+
def _read_parquet_iterator(
231+
paths: List[str],
232+
keep_files: bool,
233+
use_threads: bool,
234+
categories: List[str] = None,
235+
chunked: Union[bool, int] = True,
236+
boto3_session: Optional[boto3.Session] = None,
237+
s3_additional_kwargs: Optional[Dict[str, str]] = None,
238+
) -> Iterator[pd.DataFrame]:
239+
dfs: Iterator[pd.DataFrame] = s3.read_parquet(
240+
path=paths,
241+
categories=categories,
242+
chunked=chunked,
243+
dataset=False,
244+
use_threads=use_threads,
245+
boto3_session=boto3_session,
246+
s3_additional_kwargs=s3_additional_kwargs,
247+
)
248+
yield from dfs
249+
if keep_files is False:
250+
s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
251+
252+
230253
def to_sql(df: pd.DataFrame, con: sqlalchemy.engine.Engine, **pandas_kwargs) -> None:
231254
"""Write records stored in a DataFrame to a SQL database.
232255
@@ -1111,29 +1134,6 @@ def unload_redshift(
11111134
)
11121135

11131136

1114-
def _read_parquet_iterator(
1115-
paths: List[str],
1116-
keep_files: bool,
1117-
use_threads: bool,
1118-
categories: List[str] = None,
1119-
chunked: Union[bool, int] = True,
1120-
boto3_session: Optional[boto3.Session] = None,
1121-
s3_additional_kwargs: Optional[Dict[str, str]] = None,
1122-
) -> Iterator[pd.DataFrame]:
1123-
dfs: Iterator[pd.DataFrame] = s3.read_parquet(
1124-
path=paths,
1125-
categories=categories,
1126-
chunked=chunked,
1127-
dataset=False,
1128-
use_threads=use_threads,
1129-
boto3_session=boto3_session,
1130-
s3_additional_kwargs=s3_additional_kwargs,
1131-
)
1132-
yield from dfs
1133-
if keep_files is False:
1134-
s3.delete_objects(path=paths, use_threads=use_threads, boto3_session=boto3_session)
1135-
1136-
11371137
def unload_redshift_to_files(
11381138
sql: str,
11391139
path: str,

tests/conftest.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,30 @@ def path3(bucket):
199199

200200

201201
@pytest.fixture(scope="function")
202-
def redshift_table(databases_parameters):
202+
def redshift_table():
203203
name = f"tbl_{get_time_str_with_random_suffix()}"
204204
print(f"Table name: {name}")
205205
yield name
206206
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
207207
with engine.connect() as con:
208208
con.execute(f"DROP TABLE IF EXISTS public.{name}")
209+
210+
211+
@pytest.fixture(scope="function")
212+
def postgresql_table():
213+
name = f"tbl_{get_time_str_with_random_suffix()}"
214+
print(f"Table name: {name}")
215+
yield name
216+
engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql")
217+
with engine.connect() as con:
218+
con.execute(f"DROP TABLE IF EXISTS public.{name}")
219+
220+
221+
@pytest.fixture(scope="function")
222+
def mysql_table():
223+
name = f"tbl_{get_time_str_with_random_suffix()}"
224+
print(f"Table name: {name}")
225+
yield name
226+
engine = wr.catalog.get_engine(connection="aws-data-wrangler-mysql")
227+
with engine.connect() as con:
228+
con.execute(f"DROP TABLE IF EXISTS test.{name}")

tests/test_db.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717

1818

1919
@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
20-
def test_sql(redshift_table, databases_parameters, db_type):
20+
def test_sql(redshift_table, postgresql_table, mysql_table, databases_parameters, db_type):
21+
if db_type == "postgresql":
22+
table = postgresql_table
23+
elif db_type == "mysql":
24+
table = mysql_table
25+
else:
26+
table = redshift_table
2127
df = get_df()
2228
if db_type == "redshift":
2329
df.drop(["binary"], axis=1, inplace=True)
@@ -26,7 +32,7 @@ def test_sql(redshift_table, databases_parameters, db_type):
2632
wr.db.to_sql(
2733
df=df,
2834
con=engine,
29-
name=redshift_table,
35+
name=table,
3036
schema=databases_parameters[db_type]["schema"],
3137
if_exists="replace",
3238
index=index,
@@ -36,7 +42,7 @@ def test_sql(redshift_table, databases_parameters, db_type):
3642
dtype={"iint32": sqlalchemy.types.Integer},
3743
)
3844
df = wr.db.read_sql_query(
39-
sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.{redshift_table}", con=engine
45+
sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.{table}", con=engine
4046
)
4147
ensure_data_types(df, has_list=False)
4248
engine = wr.db.get_engine(
@@ -49,7 +55,7 @@ def test_sql(redshift_table, databases_parameters, db_type):
4955
echo=False,
5056
)
5157
dfs = wr.db.read_sql_query(
52-
sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.{redshift_table}",
58+
sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.{table}",
5359
con=engine,
5460
chunksize=1,
5561
dtype={
@@ -76,7 +82,7 @@ def test_sql(redshift_table, databases_parameters, db_type):
7682
wr.db.to_sql(
7783
df=pd.DataFrame({"col0": [1, 2, 3]}, dtype="Int32"),
7884
con=engine,
79-
name=redshift_table,
85+
name=table,
8086
schema=databases_parameters[db_type]["schema"],
8187
if_exists="replace",
8288
index=True,
@@ -85,7 +91,7 @@ def test_sql(redshift_table, databases_parameters, db_type):
8591
schema = None
8692
if db_type == "postgresql":
8793
schema = databases_parameters[db_type]["schema"]
88-
df = wr.db.read_sql_table(con=engine, table=redshift_table, schema=schema, index_col="index")
94+
df = wr.db.read_sql_table(con=engine, table=table, schema=schema, index_col="index")
8995
assert df.shape == (3, 1)
9096

9197

@@ -373,8 +379,13 @@ def test_redshift_unload_extras(bucket, path, redshift_table, databases_paramete
373379

374380

375381
@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
376-
def test_to_sql_cast(redshift_table, databases_parameters, db_type):
377-
table = redshift_table
382+
def test_to_sql_cast(redshift_table, postgresql_table, mysql_table, databases_parameters, db_type):
383+
if db_type == "postgresql":
384+
table = postgresql_table
385+
elif db_type == "mysql":
386+
table = mysql_table
387+
else:
388+
table = redshift_table
378389
schema = databases_parameters[db_type]["schema"]
379390
df = pd.DataFrame(
380391
{
@@ -403,8 +414,8 @@ def test_to_sql_cast(redshift_table, databases_parameters, db_type):
403414
assert df.equals(df2)
404415

405416

406-
def test_uuid(redshift_table, databases_parameters):
407-
table = redshift_table
417+
def test_uuid(postgresql_table, databases_parameters):
418+
table = postgresql_table
408419
schema = databases_parameters["postgresql"]["schema"]
409420
engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql")
410421
df = pd.DataFrame(
@@ -436,8 +447,13 @@ def test_uuid(redshift_table, databases_parameters):
436447

437448

438449
@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
439-
def test_null(redshift_table, databases_parameters, db_type):
440-
table = redshift_table
450+
def test_null(redshift_table, postgresql_table, mysql_table, databases_parameters, db_type):
451+
if db_type == "postgresql":
452+
table = postgresql_table
453+
elif db_type == "mysql":
454+
table = mysql_table
455+
else:
456+
table = redshift_table
441457
schema = databases_parameters[db_type]["schema"]
442458
engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
443459
df = pd.DataFrame({"id": [1, 2, 3], "nothing": [None, None, None]})

0 commit comments

Comments
 (0)