Skip to content

Commit 53ac55e

Browse files
Add postgres upsert (#811)
* Add postgres upsert
* Minor - Raising exception when no conflict cols

Co-authored-by: jaidisido <[email protected]>
1 parent 907cea6 commit 53ac55e

File tree

5 files changed

+185
-9
lines changed

5 files changed

+185
-9
lines changed

awswrangler/_databases.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,9 @@ def convert_value_to_native_python_type(value: Any) -> Any:
256256
chunk_placeholders = ", ".join([f"({column_placeholders})" for _ in range(len(parameters_chunk))])
257257
flattened_chunk = [convert_value_to_native_python_type(value) for row in parameters_chunk for value in row]
258258
yield chunk_placeholders, flattened_chunk
259+
260+
261+
def validate_mode(mode: str, allowed_modes: List[str]) -> None:
    """Check if mode is included in allowed_modes.

    Raises ``exceptions.InvalidArgumentValue`` when ``mode`` is not one of the
    ``allowed_modes``; returns ``None`` otherwise.
    """
    # Guard clause: a recognised mode needs no further work.
    if mode in allowed_modes:
        return
    raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(allowed_modes)}")

awswrangler/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No
227227
if not lst:
228228
return []
229229
n: int = num_chunks if max_length is None else int(math.ceil((float(len(lst)) / float(max_length))))
230-
np_chunks = np.array_split(lst, n) # type: ignore
230+
np_chunks = np.array_split(lst, n)
231231
return [arr.tolist() for arr in np_chunks if len(arr) > 0]
232232

233233

awswrangler/mysql.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ def to_sql(
292292
Schema name
293293
mode : str
294294
Append, overwrite, upsert_duplicate_key, upsert_replace_into, upsert_distinct.
295-
append: Inserts new records into table
296-
overwrite: Drops table and recreates
295+
append: Inserts new records into table.
296+
overwrite: Drops table and recreates.
297297
upsert_duplicate_key: Performs an upsert using `ON DUPLICATE KEY` clause. Requires table schema to have
298298
defined keys, otherwise duplicate records will be inserted.
299299
upsert_replace_into: Performs upsert using `REPLACE INTO` clause. Less efficient and still requires the
@@ -340,17 +340,16 @@ def to_sql(
340340
"""
341341
if df.empty is True:
342342
raise exceptions.EmptyDataFrame()
343+
343344
mode = mode.strip().lower()
344-
modes = [
345+
allowed_modes = [
345346
"append",
346347
"overwrite",
347348
"upsert_replace_into",
348349
"upsert_duplicate_key",
349350
"upsert_distinct",
350351
]
351-
if mode not in modes:
352-
raise exceptions.InvalidArgumentValue(f"mode must be one of {', '.join(modes)}")
353-
352+
_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
354353
_validate_connection(con=con)
355354
try:
356355
with con.cursor() as cursor:

awswrangler/postgresql.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def to_sql(
277277
varchar_lengths: Optional[Dict[str, int]] = None,
278278
use_column_names: bool = False,
279279
chunksize: int = 200,
280+
upsert_conflict_columns: Optional[List[str]] = None,
280281
) -> None:
281282
"""Write records stored in a DataFrame into PostgreSQL.
282283
@@ -291,7 +292,11 @@ def to_sql(
291292
schema : str
292293
Schema name
293294
mode : str
294-
Append or overwrite.
295+
Append, overwrite or upsert.
296+
append: Inserts new records into table.
297+
overwrite: Drops table and recreates.
298+
upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and
299+
sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode.
295300
index : bool
296301
True to store the DataFrame index as a column in the table,
297302
otherwise False to ignore it.
@@ -307,6 +312,9 @@ def to_sql(
307312
inserted into the database columns `col1` and `col3`.
308313
chunksize: int
309314
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
315+
upsert_conflict_columns: List[str], optional
316+
This parameter is only supported if `mode` is set to `upsert`. In this case conflicts for the given columns are
317+
checked for evaluating the upsert.
310318
311319
Returns
312320
-------
@@ -330,6 +338,12 @@ def to_sql(
330338
"""
331339
if df.empty is True:
332340
raise exceptions.EmptyDataFrame()
341+
342+
mode = mode.strip().lower()
343+
allowed_modes = ["append", "overwrite", "upsert"]
344+
_db_utils.validate_mode(mode=mode, allowed_modes=allowed_modes)
345+
if mode == "upsert" and not upsert_conflict_columns:
346+
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> needs to be set when using upsert mode.")
333347
_validate_connection(con=con)
334348
try:
335349
with con.cursor() as cursor:
@@ -347,13 +361,18 @@ def to_sql(
347361
df.reset_index(level=df.index.names, inplace=True)
348362
column_placeholders: str = ", ".join(["%s"] * len(df.columns))
349363
insertion_columns = ""
364+
upsert_str = ""
350365
if use_column_names:
351366
insertion_columns = f"({', '.join(df.columns)})"
367+
if mode == "upsert":
368+
upsert_columns = ", ".join(df.columns.map(lambda column: f"{column}=EXCLUDED.{column}"))
369+
conflict_columns = ", ".join(upsert_conflict_columns) # type: ignore
370+
upsert_str = f" ON CONFLICT ({conflict_columns}) DO UPDATE SET {upsert_columns}"
352371
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
353372
df=df, column_placeholders=column_placeholders, chunksize=chunksize
354373
)
355374
for placeholders, parameters in placeholder_parameter_pair_generator:
356-
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}'
375+
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}{upsert_str}'
357376
_logger.debug("sql: %s", sql)
358377
cursor.executemany(sql, (parameters,))
359378
con.commit()

tests/test_postgresql.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,155 @@ def test_dfs_are_equal_for_different_chunksizes(postgresql_table, postgresql_con
219219
df["c1"] = df["c1"].astype("string")
220220

221221
assert df.equals(df2)
222+
223+
224+
def test_upsert(postgresql_table, postgresql_con):
    """Upsert on a single-column primary key: new keys insert, conflicting keys update in place."""
    create_table_sql = (
        f"CREATE TABLE public.{postgresql_table} "
        "(c0 varchar NULL PRIMARY KEY,"
        "c1 int NULL DEFAULT 42,"
        "c2 int NOT NULL);"
    )
    with postgresql_con.cursor() as cursor:
        cursor.execute(create_table_sql)
        postgresql_con.commit()

    def _upsert(frame, conflict_columns):
        # Every write in this test targets the same table in upsert mode.
        wr.postgresql.to_sql(
            df=frame,
            con=postgresql_con,
            schema="public",
            table=postgresql_table,
            mode="upsert",
            upsert_conflict_columns=conflict_columns,
            use_column_names=True,
        )

    def _read():
        return wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table)

    df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

    # Upsert mode without conflict columns must be rejected.
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        _upsert(df, None)

    # Writing the same frame twice must not duplicate rows.
    _upsert(df, ["c0"])
    _upsert(df, ["c0"])
    df2 = _read()
    assert bool(len(df2) == 2)

    # A third identical write plus one frame with a single new key adds exactly one row.
    _upsert(df, ["c0"])
    df3 = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]})
    _upsert(df3, ["c0"])
    df4 = _read()
    assert bool(len(df4) == 3)

    # Conflicting keys update c2 in place; row count stays the same.
    df5 = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]})
    _upsert(df5, ["c0"])

    df6 = _read()
    assert bool(len(df6) == 3)
    assert bool(len(df6.loc[(df6["c0"] == "foo") & (df6["c2"] == 4)]) == 1)
    assert bool(len(df6.loc[(df6["c0"] == "bar") & (df6["c2"] == 5)]) == 1)
306+
307+
308+
def test_upsert_multiple_conflict_columns(postgresql_table, postgresql_con):
    """Upsert against a composite UNIQUE constraint: conflicts are detected on (c1, c2) together."""
    create_table_sql = (
        f"CREATE TABLE public.{postgresql_table} "
        "(c0 varchar NULL PRIMARY KEY,"
        "c1 int NOT NULL,"
        "c2 int NOT NULL,"
        "UNIQUE (c1, c2));"
    )
    with postgresql_con.cursor() as cursor:
        cursor.execute(create_table_sql)
        postgresql_con.commit()

    upsert_conflict_columns = ["c1", "c2"]

    def _upsert(frame):
        # Every write in this test targets the same table with the same conflict columns.
        wr.postgresql.to_sql(
            df=frame,
            con=postgresql_con,
            schema="public",
            table=postgresql_table,
            mode="upsert",
            upsert_conflict_columns=upsert_conflict_columns,
            use_column_names=True,
        )

    def _read():
        return wr.postgresql.read_sql_table(con=postgresql_con, schema="public", table=postgresql_table)

    df = pd.DataFrame({"c0": ["foo", "bar"], "c1": [1, 2], "c2": [3, 4]})

    # Writing the same frame twice must not duplicate rows.
    _upsert(df)
    _upsert(df)
    df2 = _read()
    assert bool(len(df2) == 2)

    # (1, 3) conflicts with an existing row and updates it; (5, 2) is new.
    df3 = pd.DataFrame({"c0": ["baz", "spam"], "c1": [1, 5], "c2": [3, 2]})
    _upsert(df3)
    df4 = _read()
    assert bool(len(df4) == 3)

    # (2, 4) and (5, 2) both conflict, so only values change — no new rows.
    df5 = pd.DataFrame({"c0": ["egg", "spam"], "c1": [2, 5], "c2": [4, 2]})
    _upsert(df5)

    df6 = _read()
    df7 = pd.DataFrame({"c0": ["baz", "egg", "spam"], "c1": [1, 2, 5], "c2": [3, 4, 2]})
    df7["c0"] = df7["c0"].astype("string")
    df7["c1"] = df7["c1"].astype("Int64")
    df7["c2"] = df7["c2"].astype("Int64")
    assert df6.equals(df7)

0 commit comments

Comments
 (0)