Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

import uuid
import warnings
from typing import TYPE_CHECKING, Any, Protocol

Expand All @@ -27,6 +28,7 @@
except ImportError:
from typing_extensions import deprecated # Python 3.12


import pyarrow as pa

from datafusion.catalog import Catalog
Expand Down Expand Up @@ -592,9 +594,19 @@ def register_listing_table(
self._convert_file_sort_order(file_sort_order),
)

def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
def sql(
self,
query: str,
options: SQLOptions | None = None,
param_values: dict[str, Any] | None = None,
**named_params: Any,
) -> DataFrame:
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.

See the online documentation for a description of how to perform
parameterized substitution via either the param_values option
or passing in named parameters.

Note: This API implements DDL statements such as ``CREATE TABLE`` and
``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
default implementation. See
Expand All @@ -603,15 +615,57 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
Args:
query: SQL query text.
options: If provided, the query will be validated against these options.
param_values: Provides substitution of scalar values in the query
after parsing.
named_params: Provides string or DataFrame substitution in the query string.

Returns:
DataFrame representation of the SQL query.
"""
if options is None:
return DataFrame(self.ctx.sql(query))
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
def value_to_scalar(value) -> pa.Scalar:
if isinstance(value, pa.Scalar):
return value
return pa.scalar(value)

def value_to_string(value) -> str:
if isinstance(value, DataFrame):
view_name = str(uuid.uuid4()).replace("-", "_")
view_name = f"view_{view_name}"
view = value.df.into_view(temporary=True)
self.ctx.register_table(view_name, view)
return view_name
return str(value)

param_values = (
{name: value_to_scalar(value) for (name, value) in param_values.items()}
if param_values is not None
else {}
)
param_strings = (
{name: value_to_string(value) for (name, value) in named_params.items()}
if named_params is not None
else {}
)

options_raw = options.options_internal if options is not None else None

return DataFrame(
self.ctx.sql_with_options(
query,
options=options_raw,
param_values=param_values,
param_strings=param_strings,
)
)

def sql_with_options(
self,
query: str,
options: SQLOptions,
param_values: dict[str, Any] | None = None,
**named_params: Any,
) -> DataFrame:
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.

This function will first validate that the query is allowed by the
Expand All @@ -620,11 +674,16 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
Args:
query: SQL query text.
options: SQL options.
param_values: Provides substitution of scalar values in the query
after parsing.
named_params: Provides string or DataFrame substitution in the query string.

Returns:
DataFrame representation of the SQL query.
"""
return self.sql(query, options)
return self.sql(
query, options=options, param_values=param_values, **named_params
)

def create_dataframe(
self,
Expand Down
4 changes: 2 additions & 2 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def __init__(self, df: DataFrameInternal) -> None:
"""
self.df = df

def into_view(self) -> Table:
def into_view(self, temporary: bool = False) -> Table:
"""Convert ``DataFrame`` into a :class:`~datafusion.Table`.

Examples:
Expand All @@ -329,7 +329,7 @@ def into_view(self) -> Table:
"""
from datafusion.catalog import Table as _Table

return _Table(self.df.into_view())
return _Table(self.df.into_view(temporary))

def __getitem__(self, key: str | list[str]) -> DataFrame:
"""Return a new :py:class`DataFrame` with the specified column or columns.
Expand Down
10 changes: 8 additions & 2 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,16 @@ def test_register_table_from_dataframe(ctx):
assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


def test_register_table_from_dataframe_into_view(ctx):
@pytest.mark.parametrize("temporary", [True, False])
def test_register_table_from_dataframe_into_view(ctx, temporary):
    """``DataFrame.into_view`` produces a registrable view table.

    The ``temporary`` flag selects between a temporary view and a regular
    view, which is reflected in ``Table.kind``.
    """
    df = ctx.from_pydict({"a": [1, 2]})
    table = df.into_view(temporary=temporary)
    assert isinstance(table, Table)
    # Kind reported by the table must track the `temporary` flag.
    expected_kind = "temporary" if temporary else "view"
    assert table.kind == expected_kind

    ctx.register_table("view_tbl", table)
    result = ctx.sql("SELECT * FROM view_tbl").collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]
Expand Down
45 changes: 44 additions & 1 deletion python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import col, udf
from datafusion import SessionContext, col, udf
from datafusion.object_store import Http
from pyarrow.csv import write_csv

Expand Down Expand Up @@ -533,3 +533,46 @@ def test_register_listing_table(

rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}


def test_parameterized_named_params(ctx, tmp_path) -> None:
    """Named parameters substitute literals and DataFrames into SQL text."""
    parquet_path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
    source_df = ctx.read_parquet(parquet_path)

    batches = ctx.sql(
        "SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
        lit_val=3,
        replaced_df=source_df,
    ).collect()

    table = pa.Table.from_batches(batches)
    assert table.to_pydict() == {"cnt": [100], "lit_val": [3]}


def test_parameterized_param_values(ctx: SessionContext) -> None:
    """Scalar parameters are bound by the SQL parser via ``param_values``.

    These substitutions are handled after parsing rather than by token
    replacement in the query string.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3, 4])],
        names=["a"],
    )
    ctx.register_record_batches("t", [[batch]])

    result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
    assert result.to_pydict() == {"a": [1, 2]}


def test_parameterized_mixed_query(ctx: SessionContext) -> None:
    """Scalar ``param_values`` and named string/DataFrame parameters combine."""
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3, 4])], names=["a"])
    ctx.register_record_batches("t", [[batch]])
    registered_df = ctx.table("t")

    # `$col_name` and `$df` are token replacements; `$val` is bound by the
    # parser from param_values.
    query = "SELECT $col_name FROM $df WHERE a < $val"
    result = ctx.sql(
        query,
        param_values={"val": 3},
        df=registered_df,
        col_name="a",
    )
    assert result.to_pydict() == {"a": [1, 2]}
40 changes: 28 additions & 12 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
Expand All @@ -40,6 +41,7 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
Expand Down Expand Up @@ -427,27 +429,41 @@ impl PySessionContext {
self.ctx.register_udtf(&name, func);
}

/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}

#[pyo3(signature = (query, options=None))]
#[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
pub fn sql_with_options(
&self,
query: &str,
options: Option<PySQLOptions>,
py: Python,
mut query: String,
options: Option<PySQLOptions>,
param_values: HashMap<String, PyScalarValue>,
param_strings: HashMap<String, String>,
) -> PyDataFusionResult<PyDataFrame> {
let options = if let Some(options) = options {
options.options
} else {
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
let df = wait_for_future(py, result)??;

let param_values = param_values
.into_iter()
.map(|(name, value)| (name, ScalarValue::from(value)))
.collect::<HashMap<_, _>>();

let state = self.ctx.state();
let dialect = state.config().options().sql_parser.dialect.as_str();

if !param_strings.is_empty() {
query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
}

let mut df = wait_for_future(py, async {
self.ctx.sql_with_options(&query, options).await
})??;

if !param_values.is_empty() {
df = df.with_param_values(param_values)?;
}

Ok(PyDataFrame::new(df))
}

Expand Down
17 changes: 11 additions & 6 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::catalog::TableProvider;
use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
Expand All @@ -47,7 +48,7 @@ use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::table::PyTable;
use crate::table::{PyTable, TempViewTable};
use crate::utils::{
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
};
Expand Down Expand Up @@ -418,11 +419,15 @@ impl PyDataFrame {
/// because we're working with Python bindings
/// where objects are shared
#[allow(clippy::wrong_self_convention)]
pub fn into_view(&self) -> PyDataFusionResult<PyTable> {
// Call the underlying Rust DataFrame::into_view method.
// Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
// so that we don't invalidate this PyDataFrame.
let table_provider = self.df.as_ref().clone().into_view();
pub fn into_view(&self, temporary: bool) -> PyDataFusionResult<PyTable> {
let table_provider = if temporary {
Arc::new(TempViewTable::new(Arc::clone(&self.df))) as Arc<dyn TableProvider>
} else {
// Call the underlying Rust DataFrame::into_view method.
// Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
// so that we don't invalidate this PyDataFrame.
self.df.as_ref().clone().into_view()
};
Ok(PyTable::from(table_provider))
}

Expand Down
1 change: 1 addition & 0 deletions src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@

pub mod exceptions;
pub mod logical;
pub(crate) mod util;
69 changes: 69 additions & 0 deletions src/sql/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError};
use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
use datafusion::sql::sqlparser::dialect::Dialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer};
use std::collections::HashMap;

/// Look up the replacement token stream for a SQL placeholder token.
///
/// `placeholder` is the raw token text including its leading `$`
/// (e.g. `"$my_df"`). Returns a clone of the tokens registered under the
/// bare name, or `None` when the token is not a `$`-style placeholder or
/// no replacement was supplied for it.
fn tokens_from_replacements(
    placeholder: &str,
    replacements: &HashMap<String, Vec<Token>>,
) -> Option<Vec<Token>> {
    // Char pattern instead of a one-character &str (clippy::single_char_pattern),
    // and a combinator chain instead of an if/else ladder.
    placeholder
        .strip_prefix('$')
        .and_then(|name| replacements.get(name).cloned())
}

/// Tokenize every replacement string using the session's SQL dialect.
///
/// Produces a map from placeholder name to the token stream that will be
/// spliced into the query wherever `$name` appears. Fails if any
/// replacement value cannot be tokenized under `dialect`.
fn get_tokens_for_string_replacement(
    dialect: &dyn Dialect,
    replacements: HashMap<String, String>,
) -> Result<HashMap<String, Vec<Token>>, DataFusionError> {
    let mut tokenized = HashMap::with_capacity(replacements.len());
    for (name, value) in replacements {
        let tokens = Tokenizer::new(dialect, &value)
            .tokenize()
            .map_err(|err| DataFusionError::External(err.into()))?;
        tokenized.insert(name, tokens);
    }
    Ok(tokenized)
}

/// Substitute `$name` placeholders in `query` with caller-supplied SQL text.
///
/// The query and each replacement value are tokenized with the dialect
/// named by `dialect` (e.g. `"generic"`), placeholder tokens are swapped
/// for the replacement token streams, and the rewritten tokens are
/// re-parsed and rendered back to a single SQL string.
///
/// # Errors
///
/// Returns an error when the dialect name is unknown, when the query or a
/// replacement value fails to tokenize, when the rewritten token stream
/// does not parse, or when it parses to anything other than exactly one
/// statement.
pub(crate) fn replace_placeholders_with_strings(
    query: &str,
    dialect: &str,
    replacements: HashMap<String, String>,
) -> Result<String, DataFusionError> {
    let dialect = dialect_from_str(dialect)
        .ok_or_else(|| plan_datafusion_err!("Unsupported SQL dialect: {dialect}."))?;

    let replacements = get_tokens_for_string_replacement(dialect.as_ref(), replacements)?;

    let tokens = Tokenizer::new(dialect.as_ref(), query)
        .tokenize()
        .map_err(|err| DataFusionError::External(err.into()))?;

    let replaced_tokens = tokens
        .into_iter()
        .flat_map(|token| {
            if let Token::Placeholder(placeholder) = &token {
                // unwrap_or_else avoids eagerly allocating the fallback vec
                // for every placeholder (clippy::or_fun_call).
                tokens_from_replacements(placeholder, &replacements)
                    .unwrap_or_else(|| vec![token])
            } else {
                vec![token]
            }
        })
        .collect::<Vec<Token>>();

    let mut statements = Parser::new(dialect.as_ref())
        .with_tokens(replaced_tokens)
        .parse_statements()
        .map_err(|err| DataFusionError::External(Box::new(err)))?;

    // Reject multi-statement (or empty) results so a replacement value
    // cannot smuggle extra statements into the query.
    if statements.len() != 1 {
        return exec_err!("placeholder replacement should return exactly one statement");
    }

    // Take ownership of the lone statement instead of indexing into the vec.
    Ok(statements.remove(0).to_string())
}
Loading
Loading