Skip to content

Commit a92643b

Browse files
committed
Add additional unit tests for parameterized queries
1 parent 6f2b49e commit a92643b

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

python/datafusion/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def value_to_string(value) -> str:
638638
return str(value)
639639

640640
param_values = (
641-
{name: value_to_scalar(value) for (name, value) in param_values}
641+
{name: value_to_scalar(value) for (name, value) in param_values.items()}
642642
if param_values is not None
643643
else {}
644644
)

python/tests/test_sql.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,18 +552,20 @@ def test_register_listing_table(
552552
assert dict(zip(rd["grp"], rd["count"], strict=False)) == {"a": 3, "b": 2}
553553

554554

555-
def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
555+
def test_parameterized_named_params(ctx, tmp_path) -> None:
556556
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
557557

558558
df = ctx.read_parquet(path)
559559
result = ctx.sql(
560-
"SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=df
560+
"SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
561+
lit_val=3,
562+
replaced_df=df,
561563
).collect()
562564
result = pa.Table.from_batches(result)
563-
assert result.to_pydict() == {"cnt": [100]}
565+
assert result.to_pydict() == {"cnt": [100], "lit_val": [3]}
564566

565567

566-
def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
568+
def test_parameterized_param_values(ctx: SessionContext) -> None:
567569
# Test the parameters that should be handled by the parser rather
568570
# than our manipulation of the query string by searching for tokens
569571
batch = pa.RecordBatch.from_arrays(
@@ -572,5 +574,22 @@ def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
572574
)
573575

574576
ctx.register_record_batches("t", [[batch]])
575-
result = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
577+
result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
578+
assert result.to_pydict() == {"a": [1, 2]}
579+
580+
581+
def test_parameterized_mixed_query(ctx: SessionContext) -> None:
582+
batch = pa.RecordBatch.from_arrays(
583+
[pa.array([1, 2, 3, 4])],
584+
names=["a"],
585+
)
586+
ctx.register_record_batches("t", [[batch]])
587+
registered_df = ctx.table("t")
588+
589+
result = ctx.sql(
590+
"SELECT $col_name FROM $df WHERE a < $val",
591+
param_values={"val": 3},
592+
df=registered_df,
593+
col_name="a",
594+
)
576595
assert result.to_pydict() == {"a": [1, 2]}

0 commit comments

Comments
 (0)