Skip to content

Commit 623411d

Browse files
committed
Completed implementation
1 parent 6e27e26 commit 623411d

File tree

2 files changed

+100
-12
lines changed

2 files changed

+100
-12
lines changed

pandas/core/generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,6 +2798,7 @@ def to_sql(
27982798
chunksize: int | None = None,
27992799
dtype: DtypeArg | None = None,
28002800
method: Literal["multi"] | Callable | None = None,
2801+
hints: dict[str, str | list[str]] | None = None,
28012802
) -> int | None:
28022803
"""
28032804
Write records stored in a DataFrame to a SQL database.
@@ -3044,6 +3045,7 @@ def to_sql(
30443045
chunksize=chunksize,
30453046
dtype=dtype,
30463047
method=method,
3048+
hints=hints,
30473049
)
30483050

30493051
@final

pandas/io/sql.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
datetime,
1919
time,
2020
)
21-
from functools import partial
2221
import re
2322
from typing import (
2423
TYPE_CHECKING,
@@ -232,6 +231,41 @@ def _wrap_result_adbc(
232231
return df
233232

234233

234+
def _process_sql_hints(
235+
hints: dict[str, str | list[str]] | None, dialect_name: str
236+
) -> str | None:
237+
if hints is None or not hints:
238+
return None
239+
240+
dialect_name = dialect_name.lower()
241+
242+
hint_value = None
243+
for key, value in hints.items():
244+
if key.lower() == dialect_name:
245+
hint_value = value
246+
break
247+
248+
if hint_value is None:
249+
return None
250+
251+
if isinstance(hint_value, list):
252+
hint_str = " ".join(hint_value)
253+
else:
254+
hint_str = str(hint_value)
255+
256+
if hint_str.strip().startswith("/*+") and hint_str.strip().endswith("*/"):
257+
return hint_str.strip()
258+
259+
if dialect_name == "oracle":
260+
return f"/*+ {hint_str} */"
261+
elif dialect_name == "mysql":
262+
return hint_str
263+
elif dialect_name == "mssql":
264+
return hint_str
265+
else:
266+
return f"/*+ {hint_str} */"
267+
268+
235269
# -----------------------------------------------------------------------------
236270
# -- Read and write to DataFrames
237271

@@ -750,6 +784,7 @@ def to_sql(
750784
dtype: DtypeArg | None = None,
751785
method: Literal["multi"] | Callable | None = None,
752786
engine: str = "auto",
787+
hints: dict[str, str | list[str]] | None = None,
753788
**engine_kwargs,
754789
) -> int | None:
755790
"""
@@ -852,6 +887,7 @@ def to_sql(
852887
dtype=dtype,
853888
method=method,
854889
engine=engine,
890+
hints=hints,
855891
**engine_kwargs,
856892
)
857893

@@ -998,7 +1034,13 @@ def create(self) -> None:
9981034
else:
9991035
self._execute_create()
10001036

1001-
def _execute_insert(self, conn, keys: list[str], data_iter) -> int:
1037+
def _execute_insert(
1038+
self,
1039+
conn,
1040+
keys: list[str],
1041+
data_iter,
1042+
hint_str: str | None = None,
1043+
) -> int:
10021044
"""
10031045
Execute SQL statement inserting data
10041046
@@ -1010,11 +1052,23 @@ def _execute_insert(self, conn, keys: list[str], data_iter) -> int:
10101052
data_iter : generator of list
10111053
Each item contains a list of values to be inserted
10121054
"""
1013-
data = [dict(zip(keys, row, strict=True)) for row in data_iter]
1014-
result = self.pd_sql.execute(self.table.insert(), data)
1055+
data = [dict(zip(keys, row, strict=False)) for row in data_iter]
1056+
1057+
if hint_str:
1058+
stmt = self.table.insert().prefix_with(hint_str)
1059+
else:
1060+
stmt = self.table.insert()
1061+
1062+
result = self.pd_sql.execute(stmt, data)
10151063
return result.rowcount
10161064

1017-
def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
1065+
def _execute_insert_multi(
1066+
self,
1067+
conn,
1068+
keys: list[str],
1069+
data_iter,
1070+
hint_str: str | None = None,
1071+
) -> int:
10181072
"""
10191073
Alternative to _execute_insert for DBs support multi-value INSERT.
10201074
@@ -1023,11 +1077,15 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
10231077
but performance degrades quickly with increase of columns.
10241078
10251079
"""
1026-
10271080
from sqlalchemy import insert
10281081

1029-
data = [dict(zip(keys, row, strict=True)) for row in data_iter]
1030-
stmt = insert(self.table).values(data)
1082+
data = [dict(zip(keys, row, strict=False)) for row in data_iter]
1083+
1084+
if hint_str:
1085+
stmt = insert(self.table).values(data).prefix_with(hint_str)
1086+
else:
1087+
stmt = insert(self.table).values(data)
1088+
10311089
result = self.pd_sql.execute(stmt)
10321090
return result.rowcount
10331091

@@ -1084,14 +1142,20 @@ def insert(
10841142
self,
10851143
chunksize: int | None = None,
10861144
method: Literal["multi"] | Callable | None = None,
1145+
hints: dict[str, str | list[str]] | None = None,
1146+
dialect_name: str | None = None,
10871147
) -> int | None:
10881148
# set insert method
10891149
if method is None:
10901150
exec_insert = self._execute_insert
10911151
elif method == "multi":
10921152
exec_insert = self._execute_insert_multi
10931153
elif callable(method):
1094-
exec_insert = partial(method, self)
1154+
1155+
def callable_wrapper(conn, keys, data_iter, hint_str=None):
1156+
return method(self, conn, keys, data_iter)
1157+
1158+
exec_insert = callable_wrapper
10951159
else:
10961160
raise ValueError(f"Invalid parameter `method`: {method}")
10971161

@@ -1108,6 +1172,9 @@ def insert(
11081172
raise ValueError("chunksize argument should be non-zero")
11091173

11101174
chunks = (nrows // chunksize) + 1
1175+
1176+
hint_str = _process_sql_hints(hints, dialect_name) if dialect_name else None
1177+
11111178
total_inserted = None
11121179
with self.pd_sql.run_transaction() as conn:
11131180
for i in range(chunks):
@@ -1119,7 +1186,7 @@ def insert(
11191186
chunk_iter = zip(
11201187
*(arr[start_i:end_i] for arr in data_list), strict=True
11211188
)
1122-
num_inserted = exec_insert(conn, keys, chunk_iter)
1189+
num_inserted = exec_insert(conn, keys, chunk_iter, hint_str)
11231190
# GH 46891
11241191
if num_inserted is not None:
11251192
if total_inserted is None:
@@ -1503,6 +1570,7 @@ def to_sql(
15031570
chunksize: int | None = None,
15041571
dtype: DtypeArg | None = None,
15051572
method: Literal["multi"] | Callable | None = None,
1573+
hints: dict[str, str | list[str]] | None = None,
15061574
engine: str = "auto",
15071575
**engine_kwargs,
15081576
) -> int | None:
@@ -1539,6 +1607,8 @@ def insert_records(
15391607
schema=None,
15401608
chunksize: int | None = None,
15411609
method=None,
1610+
hints: dict[str, str | list[str]] | None = None,
1611+
dialect_name: str | None = None,
15421612
**engine_kwargs,
15431613
) -> int | None:
15441614
"""
@@ -1563,6 +1633,8 @@ def insert_records(
15631633
schema=None,
15641634
chunksize: int | None = None,
15651635
method=None,
1636+
hints: dict[str, str | list[str]] | None = None,
1637+
dialect_name: str | None = None,
15661638
**engine_kwargs,
15671639
) -> int | None:
15681640
from sqlalchemy import exc
@@ -1975,6 +2047,7 @@ def to_sql(
19752047
dtype: DtypeArg | None = None,
19762048
method: Literal["multi"] | Callable | None = None,
19772049
engine: str = "auto",
2050+
hints: dict[str, str | list[str]] | None = None,
19782051
**engine_kwargs,
19792052
) -> int | None:
19802053
"""
@@ -2047,6 +2120,8 @@ def to_sql(
20472120
schema=schema,
20482121
chunksize=chunksize,
20492122
method=method,
2123+
hints=hints,
2124+
dialect_name=self.con.dialect.name,
20502125
**engine_kwargs,
20512126
)
20522127

@@ -2339,6 +2414,7 @@ def to_sql(
23392414
dtype: DtypeArg | None = None,
23402415
method: Literal["multi"] | Callable | None = None,
23412416
engine: str = "auto",
2417+
hints: dict[str, str | list[str]] | None = None,
23422418
**engine_kwargs,
23432419
) -> int | None:
23442420
"""
@@ -2388,6 +2464,8 @@ def to_sql(
23882464
raise NotImplementedError(
23892465
"engine != 'auto' not implemented for ADBC drivers"
23902466
)
2467+
if hints:
2468+
raise NotImplementedError("'hints' is not implemented for ADBC drivers")
23912469

23922470
if schema:
23932471
table_name = f"{schema}.{name}"
@@ -2569,7 +2647,7 @@ def insert_statement(self, *, num_rows: int) -> str:
25692647
)
25702648
return insert_statement
25712649

2572-
def _execute_insert(self, conn, keys, data_iter) -> int:
2650+
def _execute_insert(self, conn, keys, data_iter, hints) -> int:
25732651
from sqlite3 import Error
25742652

25752653
data_list = list(data_iter)
@@ -2579,7 +2657,7 @@ def _execute_insert(self, conn, keys, data_iter) -> int:
25792657
raise DatabaseError("Execution failed") from exc
25802658
return conn.rowcount
25812659

2582-
def _execute_insert_multi(self, conn, keys, data_iter) -> int:
2660+
def _execute_insert_multi(self, conn, keys, data_iter, hints) -> int:
25832661
data_list = list(data_iter)
25842662
flattened_data = [x for row in data_list for x in row]
25852663
conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data)
@@ -2816,6 +2894,7 @@ def to_sql(
28162894
dtype: DtypeArg | None = None,
28172895
method: Literal["multi"] | Callable | None = None,
28182896
engine: str = "auto",
2897+
hints: dict[str, str | list[str]] | None = None,
28192898
**engine_kwargs,
28202899
) -> int | None:
28212900
"""
@@ -2857,6 +2936,13 @@ def to_sql(
28572936
Details and a sample callable implementation can be found in the
28582937
section :ref:`insert method <io.sql.method>`.
28592938
"""
2939+
if hints:
2940+
warnings.warn(
2941+
"SQL hints are not supported for SQLite and will be ignored.",
2942+
UserWarning,
2943+
stacklevel=find_stack_level(),
2944+
)
2945+
28602946
if dtype:
28612947
if not is_dict_like(dtype):
28622948
# error: Value expression in dictionary comprehension has incompatible

0 commit comments

Comments
 (0)