Skip to content

Commit 8a52e23

Browse files
committed
Alternate approach to case expression
1 parent 2c76271 commit 8a52e23

File tree

1 file changed

+28
-72
lines changed

1 file changed

+28
-72
lines changed

src/expr/conditional_expr.rs

Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,101 +15,57 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
20-
use crate::{
21-
errors::{PyDataFusionError, PyDataFusionResult},
22-
expr::PyExpr,
23-
};
18+
use crate::{errors::PyDataFusionResult, expr::PyExpr};
19+
use datafusion::common::{exec_err, DataFusionError};
2420
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
25-
use parking_lot::{Mutex, MutexGuard};
21+
use datafusion::prelude::Expr;
2622
use pyo3::prelude::*;
2723

28-
struct CaseBuilderHandle<'a> {
29-
guard: MutexGuard<'a, Option<CaseBuilder>>,
30-
builder: Option<CaseBuilder>,
31-
}
32-
33-
impl<'a> CaseBuilderHandle<'a> {
34-
fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> {
35-
let builder = guard.take().ok_or_else(|| {
36-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
37-
})?;
38-
39-
Ok(Self {
40-
guard,
41-
builder: Some(builder),
42-
})
43-
}
44-
45-
fn builder_mut(&mut self) -> &mut CaseBuilder {
46-
self.builder
47-
.as_mut()
48-
.expect("builder should be present while handle is alive")
49-
}
50-
51-
fn into_inner(mut self) -> CaseBuilder {
52-
self.builder
53-
.take()
54-
.expect("builder should be present when consuming handle")
55-
}
56-
}
57-
58-
impl Drop for CaseBuilderHandle<'_> {
59-
fn drop(&mut self) {
60-
if let Some(builder) = self.builder.take() {
61-
*self.guard = Some(builder);
62-
}
63-
}
64-
}
65-
6624
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
67-
#[derive(Clone)]
6825
pub struct PyCaseBuilder {
69-
case_builder: Arc<Mutex<Option<CaseBuilder>>>,
26+
case_builder: CaseBuilder,
7027
}
7128

7229
impl From<CaseBuilder> for PyCaseBuilder {
7330
fn from(case_builder: CaseBuilder) -> PyCaseBuilder {
74-
PyCaseBuilder {
75-
case_builder: Arc::new(Mutex::new(Some(case_builder))),
76-
}
31+
PyCaseBuilder { case_builder }
7732
}
7833
}
7934

80-
impl PyCaseBuilder {
81-
fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> {
82-
let guard = self.case_builder.lock();
83-
CaseBuilderHandle::new(guard)
84-
}
35+
// TODO(tsaucer) upstream make CaseBuilder impl Clone
36+
fn builder_clone(case_builder: &CaseBuilder) -> Result<CaseBuilder, DataFusionError> {
37+
let Expr::Case(case) = case_builder.end()? else {
38+
return exec_err!("CaseBuilder returned an invalid expression");
39+
};
8540

86-
pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
87-
let guard = self.case_builder.lock();
88-
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner)
89-
}
41+
let (when_expr, then_expr) = case
42+
.when_then_expr
43+
.iter()
44+
.map(|(w, t)| (w.as_ref().to_owned(), t.as_ref().to_owned()))
45+
.unzip();
46+
47+
Ok(CaseBuilder::new(
48+
case.expr,
49+
when_expr,
50+
then_expr,
51+
case.else_expr,
52+
))
9053
}
9154

9255
#[pymethods]
9356
impl PyCaseBuilder {
9457
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
95-
let mut handle = self.case_builder_handle()?;
96-
let next_builder = handle.builder_mut().when(when.expr, then.expr);
97-
Ok(next_builder.into())
58+
let case_builder = builder_clone(&self.case_builder)?.when(when.expr, then.expr);
59+
Ok(PyCaseBuilder { case_builder })
9860
}
9961

10062
fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
101-
let mut handle = self.case_builder_handle()?;
102-
match handle.builder_mut().otherwise(else_expr.expr) {
103-
Ok(expr) => Ok(expr.clone().into()),
104-
Err(err) => Err(err.into()),
105-
}
63+
Ok(builder_clone(&self.case_builder)?
64+
.otherwise(else_expr.expr)?
65+
.into())
10666
}
10767

10868
fn end(&self) -> PyDataFusionResult<PyExpr> {
109-
let mut handle = self.case_builder_handle()?;
110-
match handle.builder_mut().end() {
111-
Ok(expr) => Ok(expr.clone().into()),
112-
Err(err) => Err(err.into()),
113-
}
69+
Ok(builder_clone(&self.case_builder)?.end()?.into())
11470
}
11571
}

0 commit comments

Comments
 (0)