Skip to content

Commit 520567b

Browse files
committed
Modified implementation to take in user input as string, not list
1 parent 787f1fa commit 520567b

File tree

2 files changed

+60
-58
lines changed

2 files changed

+60
-58
lines changed

pandas/io/sql.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -231,39 +231,16 @@ def _wrap_result_adbc(
231231
return df
232232

233233

234-
def _process_sql_hints(
235-
hints: dict[str, str | list[str]] | None, dialect_name: str
236-
) -> str | None:
237-
if hints is None or not hints:
234+
def _process_sql_hints(hints: dict[str, str] | None, dialect_name: str) -> str | None:
235+
if hints is None:
238236
return None
239237

240238
dialect_name = dialect_name.lower()
241-
242-
hint_value = None
243239
for key, value in hints.items():
244240
if key.lower() == dialect_name:
245-
hint_value = value
246-
break
247-
248-
if hint_value is None:
249-
return None
241+
return value
250242

251-
if isinstance(hint_value, list):
252-
hint_str = " ".join(hint_value)
253-
else:
254-
hint_str = str(hint_value)
255-
256-
if hint_str.strip().startswith("/*+") and hint_str.strip().endswith("*/"):
257-
return hint_str.strip()
258-
259-
if dialect_name == "oracle":
260-
return f"/*+ {hint_str} */"
261-
elif dialect_name == "mysql":
262-
return hint_str
263-
elif dialect_name == "mssql":
264-
return hint_str
265-
else:
266-
return f"/*+ {hint_str} */"
243+
return None
267244

268245

269246
# -----------------------------------------------------------------------------
@@ -784,7 +761,7 @@ def to_sql(
784761
dtype: DtypeArg | None = None,
785762
method: Literal["multi"] | Callable | None = None,
786763
engine: str = "auto",
787-
hints: dict[str, str | list[str]] | None = None,
764+
hints: dict[str, str] | None = None,
788765
**engine_kwargs,
789766
) -> int | None:
790767
"""
@@ -845,6 +822,23 @@ def to_sql(
845822
846823
.. versionadded:: 1.3.0
847824
825+
hints : dict[str, str], optional
826+
SQL hints to optimize insertion performance, keyed by database dialect name.
827+
Each hint should be a complete string formatted exactly as required by the
828+
target database. The user is responsible for constructing dialect-specific
829+
syntax.
830+
831+
Examples: ``{'oracle': '/*+ APPEND PARALLEL(4) */'}``
832+
``{'mysql': 'DELAYED'}``
833+
``{'mssql': 'WITH (TABLOCK)'}``
834+
835+
.. note::
836+
- Hints are database-specific and will be ignored for unsupported dialects
837+
- SQLite will raise a UserWarning (hints not supported)
838+
- ADBC connections will raise NotImplementedError
839+
840+
.. versionadded::
841+
848842
**engine_kwargs
849843
Any additional kwargs are passed to the engine.
850844
@@ -1142,7 +1136,7 @@ def insert(
11421136
self,
11431137
chunksize: int | None = None,
11441138
method: Literal["multi"] | Callable | None = None,
1145-
hints: dict[str, str | list[str]] | None = None,
1139+
hints: dict[str, str] | None = None,
11461140
dialect_name: str | None = None,
11471141
) -> int | None:
11481142
# set insert method
@@ -1570,7 +1564,7 @@ def to_sql(
15701564
chunksize: int | None = None,
15711565
dtype: DtypeArg | None = None,
15721566
method: Literal["multi"] | Callable | None = None,
1573-
hints: dict[str, str | list[str]] | None = None,
1567+
hints: dict[str, str] | None = None,
15741568
engine: str = "auto",
15751569
**engine_kwargs,
15761570
) -> int | None:
@@ -1607,7 +1601,7 @@ def insert_records(
16071601
schema=None,
16081602
chunksize: int | None = None,
16091603
method=None,
1610-
hints: dict[str, str | list[str]] | None = None,
1604+
hints: dict[str, str] | None = None,
16111605
dialect_name: str | None = None,
16121606
**engine_kwargs,
16131607
) -> int | None:
@@ -1633,7 +1627,7 @@ def insert_records(
16331627
schema=None,
16341628
chunksize: int | None = None,
16351629
method=None,
1636-
hints: dict[str, str | list[str]] | None = None,
1630+
hints: dict[str, str] | None = None,
16371631
dialect_name: str | None = None,
16381632
**engine_kwargs,
16391633
) -> int | None:
@@ -2047,7 +2041,7 @@ def to_sql(
20472041
dtype: DtypeArg | None = None,
20482042
method: Literal["multi"] | Callable | None = None,
20492043
engine: str = "auto",
2050-
hints: dict[str, str | list[str]] | None = None,
2044+
hints: dict[str, str] | None = None,
20512045
**engine_kwargs,
20522046
) -> int | None:
20532047
"""
@@ -2414,7 +2408,7 @@ def to_sql(
24142408
dtype: DtypeArg | None = None,
24152409
method: Literal["multi"] | Callable | None = None,
24162410
engine: str = "auto",
2417-
hints: dict[str, str | list[str]] | None = None,
2411+
hints: dict[str, str] | None = None,
24182412
**engine_kwargs,
24192413
) -> int | None:
24202414
"""
@@ -2894,7 +2888,7 @@ def to_sql(
28942888
dtype: DtypeArg | None = None,
28952889
method: Literal["multi"] | Callable | None = None,
28962890
engine: str = "auto",
2897-
hints: dict[str, str | list[str]] | None = None,
2891+
hints: dict[str, str] | None = None,
28982892
**engine_kwargs,
28992893
) -> int | None:
29002894
"""

pandas/tests/io/test_sql.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4407,31 +4407,25 @@ def test_xsqlite_if_exists(sqlite_buildin):
44074407
class TestProcessSQLHints:
44084408
"""Tests for _process_sql_hints helper function."""
44094409

4410-
def test_process_sql_hints_oracle_list(self):
4411-
"""Test hint processing with Oracle dialect and list input."""
4412-
hints = {"oracle": ["APPEND", "PARALLEL"]}
4413-
result = sql._process_sql_hints(hints, "oracle")
4414-
assert result == "/*+ APPEND PARALLEL */"
4415-
44164410
def test_process_sql_hints_oracle_string(self):
4417-
"""Test hint processing with Oracle dialect and string input."""
4418-
hints = {"oracle": "APPEND PARALLEL"}
4411+
"""Test hint processing with Oracle dialect - user provides complete string."""
4412+
hints = {"oracle": "/*+ APPEND PARALLEL */"}
44194413
result = sql._process_sql_hints(hints, "oracle")
44204414
assert result == "/*+ APPEND PARALLEL */"
44214415

4422-
def test_process_sql_hints_preformatted(self):
4423-
"""Test that pre-formatted hints are returned as-is."""
4424-
hints = {"oracle": "/*+ APPEND PARALLEL */"}
4416+
def test_process_sql_hints_oracle_simple(self):
4417+
"""Test hint processing with simple Oracle hint string."""
4418+
hints = {"oracle": "/*+ PARALLEL */"}
44254419
result = sql._process_sql_hints(hints, "oracle")
4426-
assert result == "/*+ APPEND PARALLEL */"
4420+
assert result == "/*+ PARALLEL */"
44274421

44284422
def test_process_sql_hints_case_insensitive(self):
44294423
"""Test that dialect names are case-insensitive."""
4430-
hints = {"ORACLE": ["APPEND"]}
4424+
hints = {"ORACLE": "/*+ APPEND */"}
44314425
result = sql._process_sql_hints(hints, "oracle")
44324426
assert result == "/*+ APPEND */"
44334427

4434-
hints = {"oracle": ["APPEND"]}
4428+
hints = {"oracle": "/*+ APPEND */"}
44354429
result = sql._process_sql_hints(hints, "ORACLE")
44364430
assert result == "/*+ APPEND */"
44374431

@@ -4459,9 +4453,20 @@ def test_process_sql_hints_mysql(self):
44594453

44604454
def test_process_sql_hints_mssql(self):
44614455
"""Test hint processing for SQL Server dialect."""
4462-
hints = {"mssql": "TABLOCK"}
4456+
hints = {"mssql": "WITH (TABLOCK)"}
44634457
result = sql._process_sql_hints(hints, "mssql")
4464-
assert result == "TABLOCK"
4458+
assert result == "WITH (TABLOCK)"
4459+
4460+
def test_process_sql_hints_multiple_dialects(self):
4461+
"""Test extraction from dict with multiple dialects."""
4462+
hints = {
4463+
"oracle": "/*+ PARALLEL */",
4464+
"mysql": "DELAYED",
4465+
"postgresql": "/* comment */",
4466+
}
4467+
assert sql._process_sql_hints(hints, "oracle") == "/*+ PARALLEL */"
4468+
assert sql._process_sql_hints(hints, "mysql") == "DELAYED"
4469+
assert sql._process_sql_hints(hints, "postgresql") == "/* comment */"
44654470

44664471

44674472
@pytest.mark.parametrize("conn", sqlalchemy_connectable)
@@ -4471,7 +4476,10 @@ def test_to_sql_with_hints_parameter(conn, test_frame1, request):
44714476

44724477
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
44734478
pandasSQL.to_sql(
4474-
test_frame1, "test_hints", hints={"oracle": ["APPEND"]}, if_exists="replace"
4479+
test_frame1,
4480+
"test_hints",
4481+
hints={"oracle": "/*+ APPEND */"},
4482+
if_exists="replace",
44754483
)
44764484
assert pandasSQL.has_table("test_hints")
44774485
assert count_rows(conn, "test_hints") == len(test_frame1)
@@ -4505,7 +4513,7 @@ def sample(pd_table, conn, keys, data_iter):
45054513
test_frame1,
45064514
"test_hints_method",
45074515
method=sample,
4508-
hints={"oracle": ["APPEND"]},
4516+
hints={"oracle": "/*+ APPEND */"},
45094517
)
45104518
assert pandasSQL.has_table("test_hints_method")
45114519

@@ -4524,7 +4532,7 @@ def test_to_sql_hints_with_different_methods(conn, method, test_frame1, request)
45244532
test_frame1,
45254533
"test_hints_methods",
45264534
method=method,
4527-
hints={"oracle": ["APPEND", "PARALLEL"]},
4535+
hints={"oracle": "/*+ APPEND PARALLEL */"},
45284536
if_exists="replace",
45294537
)
45304538
assert pandasSQL.has_table("test_hints_methods")
@@ -4538,10 +4546,10 @@ def test_to_sql_hints_multidb_dict(conn, test_frame1, request):
45384546
conn = request.getfixturevalue(conn)
45394547

45404548
hints = {
4541-
"oracle": ["APPEND", "PARALLEL"],
4549+
"oracle": "/*+ APPEND PARALLEL */",
45424550
"mysql": "HIGH_PRIORITY",
4543-
"postgresql": "some_pg_hint",
4544-
"sqlite": "ignored",
4551+
"postgresql": "/* pg hint */",
4552+
"sqlite": "IGNORED",
45454553
}
45464554

45474555
with pandasSQL_builder(conn, need_transaction=True) as pandasSQL:
@@ -4561,7 +4569,7 @@ def test_to_sql_hints_adbc_not_supported(sqlite_adbc_conn, test_frame1):
45614569
msg = "'hints' is not implemented for ADBC drivers"
45624570

45634571
with pytest.raises(NotImplementedError, match=msg):
4564-
df.to_sql("test", sqlite_adbc_conn, hints={"oracle": ["APPEND"]})
4572+
df.to_sql("test", sqlite_adbc_conn, hints={"mysql": "SOME_HINT"})
45654573

45664574

45674575
def test_to_sql_hints_sqlite_builtin(sqlite_buildin, test_frame1):

0 commit comments

Comments
 (0)