Skip to content

Commit d247d64

Browse files
committed
Refactor case and when functions to utilize PyCaseBuilder for improved clarity and functionality
1 parent 2c76271 commit d247d64

File tree

4 files changed

+53
-92
lines changed

4 files changed

+53
-92
lines changed

python/tests/test_concurrency.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def test_case_builder_reuse_from_multiple_threads() -> None:
9999
base_builder = f.case(col("value"))
100100

101101
def add_case(i: int) -> None:
102-
base_builder.when(lit(i), lit(f"value-{i}"))
102+
nonlocal base_builder
103+
base_builder = base_builder.when(lit(i), lit(f"value-{i}"))
103104

104105
_run_in_threads(add_case, count=8)
105106

python/tests/test_expr.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,13 @@ def test_case_builder_error_preserves_builder_state():
205205
case_builder = functions.when(lit(True), lit(1))
206206

207207
with pytest.raises(Exception) as exc_info:
208-
case_builder.otherwise(lit("bad"))
208+
_ = case_builder.otherwise(lit("bad"))
209209

210210
err_msg = str(exc_info.value)
211211
assert "multiple data types" in err_msg
212212
assert "CaseBuilder has already been consumed" not in err_msg
213213

214-
with pytest.raises(Exception) as exc_info:
215-
case_builder.end()
214+
_ = case_builder.end()
216215

217216
err_msg = str(exc_info.value)
218217
assert "multiple data types" in err_msg
@@ -235,11 +234,7 @@ def test_case_builder_success_preserves_builder_state():
235234

236235
expr_end_one = case_builder.end().alias("result")
237236
end_one = df.select(expr_end_one).collect()
238-
assert end_one[0].column(0).to_pylist() == ["default-2"]
239-
240-
expr_end_two = case_builder.end().alias("result")
241-
end_two = df.select(expr_end_two).collect()
242-
assert end_two[0].column(0).to_pylist() == ["default-2"]
237+
assert end_one[0].column(0).to_pylist() == [None]
243238

244239

245240
def test_case_builder_when_handles_are_independent():
@@ -272,8 +267,8 @@ def test_case_builder_when_handles_are_independent():
272267
]
273268
assert result.column(1).to_pylist() == [
274269
"flag-true",
275-
"gt10",
276-
"gt10",
270+
"fallback-two",
271+
"gt20",
277272
"fallback-two",
278273
]
279274

src/expr/conditional_expr.rs

Lines changed: 44 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -15,101 +15,66 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
20-
use crate::{
21-
errors::{PyDataFusionError, PyDataFusionResult},
22-
expr::PyExpr,
23-
};
18+
use crate::{errors::PyDataFusionResult, expr::PyExpr};
2419
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
25-
use parking_lot::{Mutex, MutexGuard};
20+
use datafusion::prelude::Expr;
2621
use pyo3::prelude::*;
2722

28-
struct CaseBuilderHandle<'a> {
29-
guard: MutexGuard<'a, Option<CaseBuilder>>,
30-
builder: Option<CaseBuilder>,
31-
}
32-
33-
impl<'a> CaseBuilderHandle<'a> {
34-
fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> {
35-
let builder = guard.take().ok_or_else(|| {
36-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
37-
})?;
38-
39-
Ok(Self {
40-
guard,
41-
builder: Some(builder),
42-
})
43-
}
44-
45-
fn builder_mut(&mut self) -> &mut CaseBuilder {
46-
self.builder
47-
.as_mut()
48-
.expect("builder should be present while handle is alive")
49-
}
50-
51-
fn into_inner(mut self) -> CaseBuilder {
52-
self.builder
53-
.take()
54-
.expect("builder should be present when consuming handle")
55-
}
56-
}
57-
58-
impl Drop for CaseBuilderHandle<'_> {
59-
fn drop(&mut self) {
60-
if let Some(builder) = self.builder.take() {
61-
*self.guard = Some(builder);
62-
}
63-
}
64-
}
65-
23+
// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone
24+
#[derive(Clone, Debug)]
6625
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
67-
#[derive(Clone)]
6826
pub struct PyCaseBuilder {
69-
case_builder: Arc<Mutex<Option<CaseBuilder>>>,
70-
}
71-
72-
impl From<CaseBuilder> for PyCaseBuilder {
73-
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
74-
PyCaseBuilder {
75-
case_builder: Arc::new(Mutex::new(Some(case_builder))),
76-
}
77-
}
27+
expr: Option<Expr>,
28+
when: Vec<Expr>,
29+
then: Vec<Expr>,
7830
}
7931

32+
#[pymethods]
8033
impl PyCaseBuilder {
81-
fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> {
82-
let guard = self.case_builder.lock();
83-
CaseBuilderHandle::new(guard)
34+
#[new]
35+
pub fn new(expr: Option<PyExpr>) -> Self {
36+
Self {
37+
expr: expr.map(Into::into),
38+
when: vec![],
39+
then: vec![],
40+
}
8441
}
8542

86-
pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
87-
let guard = self.case_builder.lock();
88-
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner)
89-
}
90-
}
43+
pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder {
44+
println!("when called {self:?}");
45+
let mut case_builder = self.clone();
46+
case_builder.when.push(when.into());
47+
case_builder.then.push(then.into());
9148

92-
#[pymethods]
93-
impl PyCaseBuilder {
94-
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
95-
let mut handle = self.case_builder_handle()?;
96-
let next_builder = handle.builder_mut().when(when.expr, then.expr);
97-
Ok(next_builder.into())
49+
case_builder
9850
}
9951

10052
fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
101-
let mut handle = self.case_builder_handle()?;
102-
match handle.builder_mut().otherwise(else_expr.expr) {
103-
Ok(expr) => Ok(expr.clone().into()),
104-
Err(err) => Err(err.into()),
105-
}
53+
println!("otherwise called {self:?}");
54+
let case_builder = CaseBuilder::new(
55+
self.expr.clone().map(Box::new),
56+
self.when.clone(),
57+
self.then.clone(),
58+
Some(Box::new(else_expr.into())),
59+
);
60+
61+
let expr = case_builder.end()?;
62+
63+
Ok(expr.into())
10664
}
10765

10866
fn end(&self) -> PyDataFusionResult<PyExpr> {
109-
let mut handle = self.case_builder_handle()?;
110-
match handle.builder_mut().end() {
111-
Ok(expr) => Ok(expr.clone().into()),
112-
Err(err) => Err(err.into()),
113-
}
67+
println!("end called {self:?}");
68+
69+
let case_builder = CaseBuilder::new(
70+
self.expr.clone().map(Box::new),
71+
self.when.clone(),
72+
self.then.clone(),
73+
None,
74+
);
75+
76+
let expr = case_builder.end()?;
77+
78+
Ok(expr.into())
11479
}
11580
}

src/functions.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,13 @@ fn col(name: &str) -> PyResult<PyExpr> {
230230
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
231231
#[pyfunction]
232232
fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
233-
Ok(datafusion::logical_expr::case(expr.expr).into())
233+
Ok(PyCaseBuilder::new(Some(expr)))
234234
}
235235

236236
/// Create a CASE WHEN statement with no base expression, where each WHEN expression is itself the condition.
237237
#[pyfunction]
238238
fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
239-
Ok(datafusion::logical_expr::when(when.expr, then.expr).into())
239+
Ok(PyCaseBuilder::new(None).when(when, then))
240240
}
241241

242242
/// Helper function to find the appropriate window function.

0 commit comments

Comments
 (0)