Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/user-guide/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
.. specific language governing permissions and limitations
.. under the License.

.. _configuration:

Configuration
=============

Expand Down
91 changes: 87 additions & 4 deletions docs/source/user-guide/sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,100 @@ DataFusion also offers a SQL API, read the full reference `here <https://arrow.a
.. ipython:: python

import datafusion
from datafusion import col
import pyarrow
from datafusion import DataFrame, SessionContext

# create a context
ctx = datafusion.SessionContext()

# register a CSV
ctx.register_csv('pokemon', 'pokemon.csv')
ctx.register_csv("pokemon", "pokemon.csv")

# create a new statement via SQL
df = ctx.sql('SELECT "Attack"+"Defense", "Attack"-"Defense" FROM pokemon')

# collect and convert to pandas DataFrame
df.to_pandas()
df.to_pandas()

Parameterized queries
---------------------

In DataFusion-Python 51.0.0 we introduced the ability to pass parameters
in a SQL query. These are similar in concept to
`prepared statements <https://datafusion.apache.org/user-guide/sql/prepared_statements.html>`_,
but allow passing named parameters into a SQL query. Consider this simple
example.

.. ipython:: python

def show_attacks(ctx: SessionContext, threshold: int) -> None:
ctx.sql(
'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val', val=threshold
).show(num=5)
show_attacks(ctx, 75)

When passing parameters like the example above we convert the Python objects
into their string representation. We also have special case handling
for :py:class:`~datafusion.dataframe.DataFrame` objects, since they cannot simply
be turned into string representations for an SQL query. In these cases we
will register a temporary view in the :py:class:`~datafusion.context.SessionContext`
using a generated table name.

The formatting for passing string replacement objects is to precede the
variable name with a single ``$``. This works for all dialects in
the SQL parser except ``hive`` and ``mysql``. Since these dialects do not
support named placeholders, we are unable to do this type of replacement.
We recommend either switching to another dialect or using Python
f-string style replacement.

.. warning::

To support DataFrame parameterized queries, your session must support
registration of temporary views. The default
:py:class:`~datafusion.catalog.CatalogProvider` and
:py:class:`~datafusion.catalog.SchemaProvider` do have this capability.
If you have implemented custom providers, it is important that temporary
views do not persist across :py:class:`~datafusion.context.SessionContext`
or you may get unintended consequences.

The following example shows passing in both a :py:class:`~datafusion.dataframe.DataFrame`
object as well as a Python object to be used in parameterized replacement.

.. ipython:: python

def show_column(
ctx: SessionContext, column: str, df: DataFrame, threshold: int
) -> None:
ctx.sql(
'SELECT "Name", $col FROM $df WHERE $col > $val',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another mechanism you could use instead of a parameter here is what DuckDB would call a replacement scan...in DuckDB you can select * from foo where foo can be a variable in the Python frame from whence .sql() was called. I believe you could implement that as a Catalog where fetching a table of name foo would look for a Python variable.

col=column,
df=df,
val=threshold,
).show(num=5)
df = ctx.table("pokemon")
show_column(ctx, '"Defense"', df, 75)

The approach implemented for conversion of variables into a SQL query
relies on string conversion. This has the potential for data loss,
specifically for cases like floating point numbers. If you need to pass
variables into a parameterized query and it is important to maintain the
original value without conversion to a string, then you can use the
optional parameter ``param_values`` to specify these. This parameter
expects a dictionary mapping from the parameter name to a Python
object. Those objects will be cast into a
`PyArrow Scalar Value <https://arrow.apache.org/docs/python/generated/pyarrow.Scalar.html>`_.

Using ``param_values`` will rely on the SQL dialect you have configured
for your session. This can be set using the :ref:`configuration options <configuration>`
of your :py:class:`~datafusion.context.SessionContext`. Similar to how
`prepared statements <https://datafusion.apache.org/user-guide/sql/prepared_statements.html>`_
work, these parameters are limited to places where you would pass in a
scalar value, such as a comparison.

.. ipython:: python

def param_attacks(ctx: SessionContext, threshold: int) -> None:
ctx.sql(
'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val',
param_values={"val": threshold},
).show(num=5)
param_attacks(ctx, 75)
71 changes: 65 additions & 6 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

import uuid
import warnings
from typing import TYPE_CHECKING, Any, Protocol

Expand All @@ -27,6 +28,7 @@
except ImportError:
from typing_extensions import deprecated # Python 3.12


import pyarrow as pa

from datafusion.catalog import Catalog
Expand Down Expand Up @@ -592,9 +594,19 @@ def register_listing_table(
self._convert_file_sort_order(file_sort_order),
)

def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
def sql(
self,
query: str,
options: SQLOptions | None = None,
param_values: dict[str, Any] | None = None,
**named_params: Any,
) -> DataFrame:
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.

See the online documentation for a description of how to perform
parameterized substitution via either the param_values option
or passing in named parameters.

Note: This API implements DDL statements such as ``CREATE TABLE`` and
``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
default implementation.See
Expand All @@ -603,15 +615,57 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
Args:
query: SQL query text.
options: If provided, the query will be validated against these options.
param_values: Provides substitution of scalar values in the query
after parsing.
named_params: Provides string or DataFrame substitution in the query string.

Returns:
DataFrame representation of the SQL query.
"""
if options is None:
return DataFrame(self.ctx.sql(query))
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
def value_to_scalar(value) -> pa.Scalar:
if isinstance(value, pa.Scalar):
return value
return pa.scalar(value)

def value_to_string(value) -> str:
if isinstance(value, DataFrame):
view_name = str(uuid.uuid4()).replace("-", "_")
view_name = f"view_{view_name}"
view = value.df.into_view(temporary=True)
self.ctx.register_table(view_name, view)
return view_name
return str(value)

param_values = (
{name: value_to_scalar(value) for (name, value) in param_values.items()}
if param_values is not None
else {}
)
param_strings = (
{name: value_to_string(value) for (name, value) in named_params.items()}
if named_params is not None
else {}
)

options_raw = options.options_internal if options is not None else None

return DataFrame(
self.ctx.sql_with_options(
query,
options=options_raw,
param_values=param_values,
param_strings=param_strings,
)
)

def sql_with_options(
self,
query: str,
options: SQLOptions,
param_values: dict[str, Any] | None = None,
**named_params: Any,
) -> DataFrame:
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.

This function will first validate that the query is allowed by the
Expand All @@ -620,11 +674,16 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
Args:
query: SQL query text.
options: SQL options.
param_values: Provides substitution of scalar values in the query
after parsing.
named_params: Provides string or DataFrame substitution in the query string.

Returns:
DataFrame representation of the SQL query.
"""
return self.sql(query, options)
return self.sql(
query, options=options, param_values=param_values, **named_params
)

def create_dataframe(
self,
Expand Down
45 changes: 44 additions & 1 deletion python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import col, udf
from datafusion import SessionContext, col, udf
from datafusion.object_store import Http
from pyarrow.csv import write_csv

Expand Down Expand Up @@ -533,3 +533,46 @@ def test_register_listing_table(

rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}


def test_parameterized_named_params(ctx, tmp_path) -> None:
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())

df = ctx.read_parquet(path)
result = ctx.sql(
"SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
lit_val=3,
replaced_df=df,
).collect()
result = pa.Table.from_batches(result)
assert result.to_pydict() == {"cnt": [100], "lit_val": [3]}


def test_parameterized_param_values(ctx: SessionContext) -> None:
# Test the parameters that should be handled by the parser rather
# than our manipulation of the query string by searching for tokens
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3, 4])],
names=["a"],
)

ctx.register_record_batches("t", [[batch]])
result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
assert result.to_pydict() == {"a": [1, 2]}


def test_parameterized_mixed_query(ctx: SessionContext) -> None:
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3, 4])],
names=["a"],
)
ctx.register_record_batches("t", [[batch]])
registered_df = ctx.table("t")

result = ctx.sql(
"SELECT $col_name FROM $df WHERE a < $val",
param_values={"val": 3},
df=registered_df,
col_name="a",
)
assert result.to_pydict() == {"a": [1, 2]}
40 changes: 28 additions & 12 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
Expand All @@ -40,6 +41,7 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
Expand Down Expand Up @@ -427,27 +429,41 @@ impl PySessionContext {
self.ctx.register_udtf(&name, func);
}

/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

#[pyo3(signature = (query, options=None))]
#[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
pub fn sql_with_options(
&self,
query: &str,
options: Option<PySQLOptions>,
py: Python,
mut query: String,
options: Option<PySQLOptions>,
param_values: HashMap<String, PyScalarValue>,
param_strings: HashMap<String, String>,
) -> PyDataFusionResult<PyDataFrame> {
let options = if let Some(options) = options {
options.options
} else {
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
let df = wait_for_future(py, result)??;

let param_values = param_values
.into_iter()
.map(|(name, value)| (name, ScalarValue::from(value)))
.collect::<HashMap<_, _>>();

let state = self.ctx.state();
let dialect = state.config().options().sql_parser.dialect.as_str();

if !param_strings.is_empty() {
query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
}

let mut df = wait_for_future(py, async {
self.ctx.sql_with_options(&query, options).await
})??;

if !param_values.is_empty() {
df = df.with_param_values(param_values)?;
}

Ok(PyDataFrame::new(df))
}

Expand Down
1 change: 1 addition & 0 deletions src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@

pub mod exceptions;
pub mod logical;
pub(crate) mod util;
Loading