Skip to content

Commit fbea66b

Browse files
Add flag for using column names in to_sql (#583)
* Add flag for using column names in to_sql
* Fix redshift test
* Change redshift schema

Co-authored-by: jaidisido <[email protected]>
1 parent 7d9408a commit fbea66b

File tree

8 files changed

+142
-4
lines changed

8 files changed

+142
-4
lines changed

awswrangler/mysql.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def to_sql(
266266
index: bool = False,
267267
dtype: Optional[Dict[str, str]] = None,
268268
varchar_lengths: Optional[Dict[str, int]] = None,
269+
use_column_names: bool = False,
269270
) -> None:
270271
"""Write records stored in a DataFrame into MySQL.
271272
@@ -290,6 +291,10 @@ def to_sql(
290291
(e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
291292
varchar_lengths : Dict[str, int], optional
292293
Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
294+
use_column_names: bool
295+
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
296+
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
297+
inserted into the database columns `col1` and `col3`.
293298
294299
Returns
295300
-------
@@ -329,7 +334,10 @@ def to_sql(
329334
if index:
330335
df.reset_index(level=df.index.names, inplace=True)
331336
placeholders: str = ", ".join(["%s"] * len(df.columns))
332-
sql: str = f"INSERT INTO `{schema}`.`{table}` VALUES ({placeholders})"
337+
insertion_columns = ""
338+
if use_column_names:
339+
insertion_columns = f"({', '.join(df.columns)})"
340+
sql: str = f"INSERT INTO `{schema}`.`{table}` {insertion_columns} VALUES ({placeholders})"
333341
_logger.debug("sql: %s", sql)
334342
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
335343
cursor.executemany(sql, parameters)

awswrangler/postgresql.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def to_sql(
272272
index: bool = False,
273273
dtype: Optional[Dict[str, str]] = None,
274274
varchar_lengths: Optional[Dict[str, int]] = None,
275+
use_column_names: bool = False,
275276
) -> None:
276277
"""Write records stored in a DataFrame into PostgreSQL.
277278
@@ -296,6 +297,10 @@ def to_sql(
296297
(e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
297298
varchar_lengths : Dict[str, int], optional
298299
Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
300+
use_column_names: bool
301+
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
302+
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
303+
inserted into the database columns `col1` and `col3`.
299304
300305
Returns
301306
-------
@@ -335,7 +340,10 @@ def to_sql(
335340
if index:
336341
df.reset_index(level=df.index.names, inplace=True)
337342
placeholders: str = ", ".join(["%s"] * len(df.columns))
338-
sql: str = f'INSERT INTO "{schema}"."{table}" VALUES ({placeholders})'
343+
insertion_columns = ""
344+
if use_column_names:
345+
insertion_columns = f"({', '.join(df.columns)})"
346+
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES ({placeholders})'
339347
_logger.debug("sql: %s", sql)
340348
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
341349
cursor.executemany(sql, parameters)

awswrangler/redshift.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,7 @@ def to_sql(
644644
primary_keys: Optional[List[str]] = None,
645645
varchar_lengths_default: int = 256,
646646
varchar_lengths: Optional[Dict[str, int]] = None,
647+
use_column_names: bool = False,
647648
) -> None:
648649
"""Write records stored in a DataFrame into Redshift.
649650
@@ -688,6 +689,10 @@ def to_sql(
688689
The size that will be set for all VARCHAR columns not specified with varchar_lengths.
689690
varchar_lengths : Dict[str, int], optional
690691
Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
692+
use_column_names: bool
693+
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
694+
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
695+
inserted into the database columns `col1` and `col3`.
691696
692697
Returns
693698
-------
@@ -737,7 +742,10 @@ def to_sql(
737742
df.reset_index(level=df.index.names, inplace=True)
738743
placeholders: str = ", ".join(["%s"] * len(df.columns))
739744
schema_str = f'"{created_schema}".' if created_schema else ""
740-
sql: str = f'INSERT INTO {schema_str}"{created_table}" VALUES ({placeholders})'
745+
insertion_columns = ""
746+
if use_column_names:
747+
insertion_columns = f"({', '.join(df.columns)})"
748+
sql: str = f'INSERT INTO {schema_str}"{created_table}" {insertion_columns} VALUES ({placeholders})'
741749
_logger.debug("sql: %s", sql)
742750
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
743751
cursor.executemany(sql, parameters)

awswrangler/sqlserver.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def to_sql(
290290
index: bool = False,
291291
dtype: Optional[Dict[str, str]] = None,
292292
varchar_lengths: Optional[Dict[str, int]] = None,
293+
use_column_names: bool = False,
293294
) -> None:
294295
"""Write records stored in a DataFrame into Microsoft SQL Server.
295296
@@ -314,6 +315,10 @@ def to_sql(
314315
(e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
315316
varchar_lengths : Dict[str, int], optional
316317
Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
318+
use_column_names: bool
319+
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
320+
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
321+
inserted into the database columns `col1` and `col3`.
317322
318323
Returns
319324
-------
@@ -354,7 +359,10 @@ def to_sql(
354359
df.reset_index(level=df.index.names, inplace=True)
355360
placeholders: str = ", ".join(["?"] * len(df.columns))
356361
table_identifier = _get_table_identifier(schema, table)
357-
sql: str = f"INSERT INTO {table_identifier} VALUES ({placeholders})"
362+
insertion_columns = ""
363+
if use_column_names:
364+
insertion_columns = f"({', '.join(df.columns)})"
365+
sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES ({placeholders})"
358366
_logger.debug("sql: %s", sql)
359367
parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
360368
cursor.executemany(sql, parameters)

tests/test_mysql.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,29 @@ def test_connect_secret_manager(dbname):
180180
df = wr.mysql.read_sql_query("SELECT 1", con=con)
181181
con.close()
182182
assert df.shape == (1, 1)
183+
184+
185+
def test_insert_with_column_names(mysql_table):
186+
con = wr.mysql.connect(connection="aws-data-wrangler-mysql")
187+
create_table_sql = (
188+
f"CREATE TABLE test.{mysql_table} " "(c0 varchar(100) NULL, " "c1 INT DEFAULT 42 NULL, " "c2 INT NOT NULL);"
189+
)
190+
with con.cursor() as cursor:
191+
cursor.execute(create_table_sql)
192+
con.commit()
193+
194+
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
195+
196+
with pytest.raises(pymysql.err.OperationalError):
197+
wr.mysql.to_sql(df=df, con=con, schema="test", table=mysql_table, mode="append", use_column_names=False)
198+
199+
wr.mysql.to_sql(df=df, con=con, schema="test", table=mysql_table, mode="append", use_column_names=True)
200+
201+
df2 = wr.mysql.read_sql_table(con=con, schema="test", table=mysql_table)
202+
203+
df["c1"] = 42
204+
df["c0"] = df["c0"].astype("string")
205+
df["c1"] = df["c1"].astype("Int64")
206+
df["c2"] = df["c2"].astype("Int64")
207+
df = df.reindex(sorted(df.columns), axis=1)
208+
assert df.equals(df2)

tests/test_postgresql.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,31 @@ def test_connect_secret_manager(dbname):
180180
df = wr.postgresql.read_sql_query("SELECT 1", con=con)
181181
con.close()
182182
assert df.shape == (1, 1)
183+
184+
185+
def test_insert_with_column_names(postgresql_table):
186+
con = wr.postgresql.connect(connection="aws-data-wrangler-postgresql")
187+
create_table_sql = (
188+
f"CREATE TABLE public.{postgresql_table} " "(c0 varchar NULL," "c1 int NULL DEFAULT 42," "c2 int NOT NULL);"
189+
)
190+
with con.cursor() as cursor:
191+
cursor.execute(create_table_sql)
192+
con.commit()
193+
194+
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
195+
196+
with pytest.raises(pg8000.exceptions.ProgrammingError):
197+
wr.postgresql.to_sql(
198+
df=df, con=con, schema="public", table=postgresql_table, mode="append", use_column_names=False
199+
)
200+
201+
wr.postgresql.to_sql(df=df, con=con, schema="public", table=postgresql_table, mode="append", use_column_names=True)
202+
203+
df2 = wr.postgresql.read_sql_table(con=con, schema="public", table=postgresql_table)
204+
205+
df["c1"] = 42
206+
df["c0"] = df["c0"].astype("string")
207+
df["c1"] = df["c1"].astype("Int64")
208+
df["c2"] = df["c2"].astype("Int64")
209+
df = df.reindex(sorted(df.columns), axis=1)
210+
assert df.equals(df2)

tests/test_redshift.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,3 +911,29 @@ def test_failed_keep_files(path, redshift_table, databases_parameters):
911911
varchar_lengths={"c1": 2},
912912
)
913913
assert len(wr.s3.list_objects(path)) == 0
914+
915+
916+
def test_insert_with_column_names(redshift_table):
917+
con = wr.redshift.connect(connection="aws-data-wrangler-redshift")
918+
create_table_sql = (
919+
f"CREATE TABLE public.{redshift_table} " "(c0 varchar(100), " "c1 integer default 42, " "c2 integer not null);"
920+
)
921+
with con.cursor() as cursor:
922+
cursor.execute(create_table_sql)
923+
con.commit()
924+
925+
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
926+
927+
with pytest.raises(redshift_connector.error.ProgrammingError):
928+
wr.redshift.to_sql(df=df, con=con, schema="public", table=redshift_table, mode="append", use_column_names=False)
929+
930+
wr.redshift.to_sql(df=df, con=con, schema="public", table=redshift_table, mode="append", use_column_names=True)
931+
932+
df2 = wr.redshift.read_sql_table(con=con, schema="public", table=redshift_table)
933+
934+
df["c1"] = 42
935+
df["c0"] = df["c0"].astype("string")
936+
df["c1"] = df["c1"].astype("Int64")
937+
df["c2"] = df["c2"].astype("Int64")
938+
df = df.reindex(sorted(df.columns), axis=1)
939+
assert df.equals(df2)

tests/test_sqlserver.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,29 @@ def test_connect_secret_manager(dbname):
194194
assert df.shape == (1, 1)
195195
except boto3.client("secretsmanager").exceptions.ResourceNotFoundException:
196196
pass # Workaround for secretmanager inconsistance
197+
198+
199+
def test_insert_with_column_names(sqlserver_table):
200+
con = wr.sqlserver.connect(connection="aws-data-wrangler-sqlserver")
201+
create_table_sql = (
202+
f"CREATE TABLE dbo.{sqlserver_table} " "(c0 varchar(100) NULL," "c1 INT DEFAULT 42 NULL," "c2 INT NOT NULL);"
203+
)
204+
with con.cursor() as cursor:
205+
cursor.execute(create_table_sql)
206+
con.commit()
207+
208+
df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})
209+
210+
with pytest.raises(pyodbc.ProgrammingError):
211+
wr.sqlserver.to_sql(df=df, con=con, schema="dbo", table=sqlserver_table, mode="append", use_column_names=False)
212+
213+
wr.sqlserver.to_sql(df=df, con=con, schema="dbo", table=sqlserver_table, mode="append", use_column_names=True)
214+
215+
df2 = wr.sqlserver.read_sql_table(con=con, schema="dbo", table=sqlserver_table)
216+
217+
df["c1"] = 42
218+
df["c0"] = df["c0"].astype("string")
219+
df["c1"] = df["c1"].astype("Int64")
220+
df["c2"] = df["c2"].astype("Int64")
221+
df = df.reindex(sorted(df.columns), axis=1)
222+
assert df.equals(df2)

0 commit comments

Comments
 (0)