Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None:
base_builder = f.case(col("value"))

def add_case(i: int) -> None:
base_builder.when(lit(i), lit(f"value-{i}"))
nonlocal base_builder
base_builder = base_builder.when(lit(i), lit(f"value-{i}"))

_run_in_threads(add_case, count=8)

Expand Down
15 changes: 5 additions & 10 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state():
case_builder = functions.when(lit(True), lit(1))

with pytest.raises(Exception) as exc_info:
case_builder.otherwise(lit("bad"))
_ = case_builder.otherwise(lit("bad"))

err_msg = str(exc_info.value)
assert "multiple data types" in err_msg
assert "CaseBuilder has already been consumed" not in err_msg

with pytest.raises(Exception) as exc_info:
case_builder.end()
_ = case_builder.end()

err_msg = str(exc_info.value)
assert "multiple data types" in err_msg
Expand All @@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state():

expr_end_one = case_builder.end().alias("result")
end_one = df.select(expr_end_one).collect()
assert end_one[0].column(0).to_pylist() == ["default-2"]

expr_end_two = case_builder.end().alias("result")
end_two = df.select(expr_end_two).collect()
assert end_two[0].column(0).to_pylist() == ["default-2"]
assert end_one[0].column(0).to_pylist() == [None]


def test_case_builder_when_handles_are_independent():
Expand Down Expand Up @@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent():
]
assert result.column(1).to_pylist() == [
"flag-true",
"gt10",
"gt10",
"fallback-two",
"gt20",
"fallback-two",
]

Expand Down
123 changes: 44 additions & 79 deletions src/expr/conditional_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,101 +15,66 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::{
errors::{PyDataFusionError, PyDataFusionResult},
expr::PyExpr,
};
use crate::{errors::PyDataFusionResult, expr::PyExpr};
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
use parking_lot::{Mutex, MutexGuard};
use datafusion::prelude::Expr;
use pyo3::prelude::*;

struct CaseBuilderHandle<'a> {
guard: MutexGuard<'a, Option<CaseBuilder>>,
builder: Option<CaseBuilder>,
}

impl<'a> CaseBuilderHandle<'a> {
fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> {
let builder = guard.take().ok_or_else(|| {
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
})?;

Ok(Self {
guard,
builder: Some(builder),
})
}

fn builder_mut(&mut self) -> &mut CaseBuilder {
self.builder
.as_mut()
.expect("builder should be present while handle is alive")
}

fn into_inner(mut self) -> CaseBuilder {
self.builder
.take()
.expect("builder should be present when consuming handle")
}
}

impl Drop for CaseBuilderHandle<'_> {
fn drop(&mut self) {
if let Some(builder) = self.builder.take() {
*self.guard = Some(builder);
}
}
}

// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone
#[derive(Clone, Debug)]
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
#[derive(Clone)]
pub struct PyCaseBuilder {
case_builder: Arc<Mutex<Option<CaseBuilder>>>,
}

impl From<CaseBuilder> for PyCaseBuilder {
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
PyCaseBuilder {
case_builder: Arc::new(Mutex::new(Some(case_builder))),
}
}
expr: Option<Expr>,
when: Vec<Expr>,
then: Vec<Expr>,
}

#[pymethods]
impl PyCaseBuilder {
fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> {
let guard = self.case_builder.lock();
CaseBuilderHandle::new(guard)
#[new]
pub fn new(expr: Option<PyExpr>) -> Self {
Self {
expr: expr.map(Into::into),
when: vec![],
then: vec![],
}
}

pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
let guard = self.case_builder.lock();
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner)
}
}
pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder {
println!("when called {self:?}");
let mut case_builder = self.clone();
case_builder.when.push(when.into());
case_builder.then.push(then.into());

#[pymethods]
impl PyCaseBuilder {
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
let mut handle = self.case_builder_handle()?;
let next_builder = handle.builder_mut().when(when.expr, then.expr);
Ok(next_builder.into())
case_builder
}

fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
let mut handle = self.case_builder_handle()?;
match handle.builder_mut().otherwise(else_expr.expr) {
Ok(expr) => Ok(expr.clone().into()),
Err(err) => Err(err.into()),
}
println!("otherwise called {self:?}");
let case_builder = CaseBuilder::new(
self.expr.clone().map(Box::new),
self.when.clone(),
self.then.clone(),
Some(Box::new(else_expr.into())),
);

let expr = case_builder.end()?;

Ok(expr.into())
}

fn end(&self) -> PyDataFusionResult<PyExpr> {
let mut handle = self.case_builder_handle()?;
match handle.builder_mut().end() {
Ok(expr) => Ok(expr.clone().into()),
Err(err) => Err(err.into()),
}
println!("end called {self:?}");

let case_builder = CaseBuilder::new(
self.expr.clone().map(Box::new),
self.when.clone(),
self.then.clone(),
None,
);

let expr = case_builder.end()?;

Ok(expr.into())
}
}
4 changes: 2 additions & 2 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult<PyExpr> {
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
#[pyfunction]
fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
Ok(datafusion::logical_expr::case(expr.expr).into())
Ok(PyCaseBuilder::new(Some(expr)))
}

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
#[pyfunction]
fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
Ok(datafusion::logical_expr::when(when.expr, then.expr).into())
Ok(PyCaseBuilder::new(None).when(when, then))
}

/// Helper function to find the appropriate window function.
Expand Down