|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use std::sync::Arc; |
19 | | - |
20 | | -use crate::{ |
21 | | - errors::{PyDataFusionError, PyDataFusionResult}, |
22 | | - expr::PyExpr, |
23 | | -}; |
| 18 | +use crate::{errors::PyDataFusionResult, expr::PyExpr}; |
| 19 | +use datafusion::common::{exec_err, DataFusionError}; |
24 | 20 | use datafusion::logical_expr::conditional_expressions::CaseBuilder; |
25 | | -use parking_lot::{Mutex, MutexGuard}; |
| 21 | +use datafusion::prelude::Expr; |
26 | 22 | use pyo3::prelude::*; |
27 | 23 |
|
28 | | -struct CaseBuilderHandle<'a> { |
29 | | - guard: MutexGuard<'a, Option<CaseBuilder>>, |
30 | | - builder: Option<CaseBuilder>, |
31 | | -} |
32 | | - |
33 | | -impl<'a> CaseBuilderHandle<'a> { |
34 | | - fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> { |
35 | | - let builder = guard.take().ok_or_else(|| { |
36 | | - PyDataFusionError::Common("CaseBuilder has already been consumed".to_string()) |
37 | | - })?; |
38 | | - |
39 | | - Ok(Self { |
40 | | - guard, |
41 | | - builder: Some(builder), |
42 | | - }) |
43 | | - } |
44 | | - |
45 | | - fn builder_mut(&mut self) -> &mut CaseBuilder { |
46 | | - self.builder |
47 | | - .as_mut() |
48 | | - .expect("builder should be present while handle is alive") |
49 | | - } |
50 | | - |
51 | | - fn into_inner(mut self) -> CaseBuilder { |
52 | | - self.builder |
53 | | - .take() |
54 | | - .expect("builder should be present when consuming handle") |
55 | | - } |
56 | | -} |
57 | | - |
58 | | -impl Drop for CaseBuilderHandle<'_> { |
59 | | - fn drop(&mut self) { |
60 | | - if let Some(builder) = self.builder.take() { |
61 | | - *self.guard = Some(builder); |
62 | | - } |
63 | | - } |
64 | | -} |
65 | | - |
66 | 24 | #[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] |
67 | | -#[derive(Clone)] |
68 | 25 | pub struct PyCaseBuilder { |
69 | | - case_builder: Arc<Mutex<Option<CaseBuilder>>>, |
| 26 | + case_builder: CaseBuilder, |
70 | 27 | } |
71 | 28 |
|
72 | 29 | impl From<CaseBuilder> for PyCaseBuilder { |
73 | 30 | fn from(case_builder: CaseBuilder) -> PyCaseBuilder { |
74 | | - PyCaseBuilder { |
75 | | - case_builder: Arc::new(Mutex::new(Some(case_builder))), |
76 | | - } |
| 31 | + PyCaseBuilder { case_builder } |
77 | 32 | } |
78 | 33 | } |
79 | 34 |
|
80 | | -impl PyCaseBuilder { |
81 | | - fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> { |
82 | | - let guard = self.case_builder.lock(); |
83 | | - CaseBuilderHandle::new(guard) |
84 | | - } |
| 35 | +// TODO(tsaucer) upstream make CaseBuilder impl Clone |
| 36 | +fn builder_clone(case_builder: &CaseBuilder) -> Result<CaseBuilder, DataFusionError> { |
| 37 | + let Expr::Case(case) = case_builder.end()? else { |
| 38 | + return exec_err!("CaseBuilder returned an invalid expression"); |
| 39 | + }; |
85 | 40 |
|
86 | | - pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> { |
87 | | - let guard = self.case_builder.lock(); |
88 | | - CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner) |
89 | | - } |
| 41 | + let (when_expr, then_expr) = case |
| 42 | + .when_then_expr |
| 43 | + .iter() |
| 44 | + .map(|(w, t)| (w.as_ref().to_owned(), t.as_ref().to_owned())) |
| 45 | + .unzip(); |
| 46 | + |
| 47 | + Ok(CaseBuilder::new( |
| 48 | + case.expr, |
| 49 | + when_expr, |
| 50 | + then_expr, |
| 51 | + case.else_expr, |
| 52 | + )) |
90 | 53 | } |
91 | 54 |
|
92 | 55 | #[pymethods] |
93 | 56 | impl PyCaseBuilder { |
94 | 57 | fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> { |
95 | | - let mut handle = self.case_builder_handle()?; |
96 | | - let next_builder = handle.builder_mut().when(when.expr, then.expr); |
97 | | - Ok(next_builder.into()) |
| 58 | + let case_builder = builder_clone(&self.case_builder)?.when(when.expr, then.expr); |
| 59 | + Ok(PyCaseBuilder { case_builder }) |
98 | 60 | } |
99 | 61 |
|
100 | 62 | fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> { |
101 | | - let mut handle = self.case_builder_handle()?; |
102 | | - match handle.builder_mut().otherwise(else_expr.expr) { |
103 | | - Ok(expr) => Ok(expr.clone().into()), |
104 | | - Err(err) => Err(err.into()), |
105 | | - } |
| 63 | + Ok(builder_clone(&self.case_builder)? |
| 64 | + .otherwise(else_expr.expr)? |
| 65 | + .into()) |
106 | 66 | } |
107 | 67 |
|
108 | 68 | fn end(&self) -> PyDataFusionResult<PyExpr> { |
109 | | - let mut handle = self.case_builder_handle()?; |
110 | | - match handle.builder_mut().end() { |
111 | | - Ok(expr) => Ok(expr.clone().into()), |
112 | | - Err(err) => Err(err.into()), |
113 | | - } |
| 69 | + Ok(builder_clone(&self.case_builder)?.end()?.into()) |
114 | 70 | } |
115 | 71 | } |
0 commit comments