Skip to content

Commit a905154

Browse files
committed
Resolve patch apply conflicts for CaseBuilder concurrency improvements
- Added a `CaseBuilderHandle` guard that keeps the underlying `CaseBuilder` alive while holding the mutex and restores it on drop.
- Updated the `when`, `otherwise`, and `end` methods to operate through the guard and consume the builder explicitly.
- This prevents transient `None` states during concurrent access and improves thread safety.
1 parent 5caec09 commit a905154

File tree

2 files changed

+76
-25
lines changed

2 files changed

+76
-25
lines changed

python/tests/test_expr.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import re
19+
from concurrent.futures import ThreadPoolExecutor
1920
from datetime import datetime, timezone
2021

2122
import pyarrow as pa
@@ -281,6 +282,26 @@ def test_case_builder_when_handles_are_independent():
281282
]
282283

283284

285+
def test_case_builder_when_thread_safe():
286+
case_builder = functions.when(lit(_TRUE), lit(1))
287+
288+
def build_expr(value: int) -> bool:
289+
builder = case_builder.when(lit(_TRUE), lit(value))
290+
builder.otherwise(lit(value))
291+
return True
292+
293+
with ThreadPoolExecutor(max_workers=8) as executor:
294+
futures = [executor.submit(build_expr, idx) for idx in range(16)]
295+
results = [future.result() for future in futures]
296+
297+
assert all(results)
298+
299+
# Ensure the shared builder remains usable after concurrent `when` calls.
300+
follow_up_builder = case_builder.when(lit(_TRUE), lit(42))
301+
assert isinstance(follow_up_builder, type(case_builder))
302+
follow_up_builder.otherwise(lit(7))
303+
304+
284305
def test_expr_getitem() -> None:
285306
ctx = SessionContext()
286307
data = {

src/expr/conditional_expr.rs

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,46 @@ use crate::{
2222
expr::PyExpr,
2323
};
2424
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
25+
use parking_lot::{Mutex, MutexGuard};
2526
use pyo3::prelude::*;
2627

27-
use parking_lot::{Mutex, MutexGuard};
28+
struct CaseBuilderHandle<'a> {
29+
guard: MutexGuard<'a, Option<CaseBuilder>>,
30+
builder: Option<CaseBuilder>,
31+
}
32+
33+
impl<'a> CaseBuilderHandle<'a> {
34+
fn new(mut guard: MutexGuard<'a, Option<CaseBuilder>>) -> PyDataFusionResult<Self> {
35+
let builder = guard.take().ok_or_else(|| {
36+
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
37+
})?;
38+
39+
Ok(Self {
40+
guard,
41+
builder: Some(builder),
42+
})
43+
}
44+
45+
fn builder_mut(&mut self) -> &mut CaseBuilder {
46+
self.builder
47+
.as_mut()
48+
.expect("builder should be present while handle is alive")
49+
}
50+
51+
fn into_inner(mut self) -> CaseBuilder {
52+
self.builder
53+
.take()
54+
.expect("builder should be present when consuming handle")
55+
}
56+
}
57+
58+
impl Drop for CaseBuilderHandle<'_> {
59+
fn drop(&mut self) {
60+
if let Some(builder) = self.builder.take() {
61+
*self.guard = Some(builder);
62+
}
63+
}
64+
}
2865

2966
#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)]
3067
#[derive(Clone)]
@@ -41,45 +78,38 @@ impl From<CaseBuilder> for PyCaseBuilder {
4178
}
4279

4380
impl PyCaseBuilder {
44-
fn lock_case_builder(&self) -> MutexGuard<'_, Option<CaseBuilder>> {
45-
self.case_builder.lock()
81+
fn case_builder_handle(&self) -> PyDataFusionResult<CaseBuilderHandle<'_>> {
82+
let guard = self.case_builder.lock();
83+
CaseBuilderHandle::new(guard)
4684
}
4785

4886
pub fn into_case_builder(self) -> PyDataFusionResult<CaseBuilder> {
49-
let mut guard = self.case_builder.lock();
50-
guard.take().ok_or_else(|| {
51-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
52-
})
87+
let guard = self.case_builder.lock();
88+
CaseBuilderHandle::new(guard).map(CaseBuilderHandle::into_inner)
5389
}
5490
}
5591

5692
#[pymethods]
5793
impl PyCaseBuilder {
5894
fn when(&self, when: PyExpr, then: PyExpr) -> PyDataFusionResult<PyCaseBuilder> {
59-
let mut guard = self.lock_case_builder();
60-
let builder = guard.as_mut().ok_or_else(|| {
61-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
62-
})?;
63-
let next_builder = builder.when(when.expr, then.expr);
95+
let mut handle = self.case_builder_handle()?;
96+
let next_builder = handle.builder_mut().when(when.expr, then.expr);
6497
Ok(next_builder.into())
6598
}
6699

67100
fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
68-
let mut guard = self.lock_case_builder();
69-
let builder = guard.as_mut().ok_or_else(|| {
70-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
71-
})?;
72-
builder
73-
.otherwise(else_expr.expr)
74-
.map(|expr| expr.into())
75-
.map_err(Into::into)
101+
let mut handle = self.case_builder_handle()?;
102+
match handle.builder_mut().otherwise(else_expr.expr) {
103+
Ok(expr) => Ok(expr.clone().into()),
104+
Err(err) => Err(err.into()),
105+
}
76106
}
77107

78108
fn end(&self) -> PyDataFusionResult<PyExpr> {
79-
let mut guard = self.lock_case_builder();
80-
let builder = guard.as_mut().ok_or_else(|| {
81-
PyDataFusionError::Common("CaseBuilder has already been consumed".to_string())
82-
})?;
83-
builder.end().map(|expr| expr.into()).map_err(Into::into)
109+
let mut handle = self.case_builder_handle()?;
110+
match handle.builder_mut().end() {
111+
Ok(expr) => Ok(expr.clone().into()),
112+
Err(err) => Err(err.into()),
113+
}
84114
}
85115
}

0 commit comments

Comments
 (0)