Skip to content

Commit fc9ab6b

Browse files
committed
feat(cube): Add cube_ext and string_boolean_coercion
1 parent a43ce8b commit fc9ab6b

File tree

7 files changed

+320
-0
lines changed

7 files changed

+320
-0
lines changed

datafusion/core/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,14 @@ pub mod variable {
610610
pub use datafusion_expr::var_provider::{VarProvider, VarType};
611611
}
612612

613+
pub mod cube_ext {
614+
pub use datafusion_physical_plan::cube_ext::*;
615+
}
616+
617+
pub mod dfschema {
618+
pub use datafusion_common::*;
619+
}
620+
613621
#[cfg(test)]
614622
pub mod test;
615623
pub mod test_util;

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
502502
.or_else(|| list_coercion(lhs_type, rhs_type))
503503
.or_else(|| null_coercion(lhs_type, rhs_type))
504504
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
505+
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
505506
.or_else(|| string_temporal_coercion(lhs_type, rhs_type))
506507
.or_else(|| binary_coercion(lhs_type, rhs_type))
507508
.or_else(|| struct_coercion(lhs_type, rhs_type))
@@ -536,6 +537,19 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
536537
}
537538
}
538539

540+
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
541+
/// where one is boolean and one is `Utf8`/`LargeUtf8`.
542+
fn string_boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
543+
use arrow::datatypes::DataType::*;
544+
match (lhs_type, rhs_type) {
545+
(Utf8, Boolean) => Some(Utf8),
546+
(LargeUtf8, Boolean) => Some(LargeUtf8),
547+
(Boolean, Utf8) => Some(Utf8),
548+
(Boolean, LargeUtf8) => Some(LargeUtf8),
549+
_ => None,
550+
}
551+
}
552+
539553
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
540554
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
541555
///

datafusion/physical-plan/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ parking_lot = { workspace = true }
6666
pin-project-lite = "^0.2.7"
6767
rand = { workspace = true }
6868
tokio = { workspace = true }
69+
serde = { version = "1.0.214", features = ["derive"] }
70+
tracing = "0.1.25"
71+
tracing-futures = { version = "0.2.5" }
6972

7073
[dev-dependencies]
7174
rstest = { workspace = true }
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::error::ArrowError;
19+
use futures::future::FutureExt;
20+
use std::fmt::{Display, Formatter};
21+
use std::future::Future;
22+
use std::panic::{catch_unwind, AssertUnwindSafe};
23+
use datafusion_common::DataFusionError;
24+
25+
#[derive(PartialEq, Debug)]
26+
pub struct PanicError {
27+
pub msg: String,
28+
}
29+
30+
impl PanicError {
31+
pub fn new(msg: String) -> PanicError {
32+
PanicError { msg }
33+
}
34+
}
35+
36+
impl Display for PanicError {
37+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38+
write!(f, "Panic: {}", self.msg)
39+
}
40+
}
41+
42+
impl From<PanicError> for ArrowError {
43+
fn from(error: PanicError) -> Self {
44+
ArrowError::ComputeError(format!("Panic: {}", error.msg))
45+
}
46+
}
47+
48+
impl From<PanicError> for DataFusionError {
49+
fn from(error: PanicError) -> Self {
50+
DataFusionError::Internal(error.msg)
51+
}
52+
}
53+
54+
pub fn try_with_catch_unwind<F, R>(f: F) -> Result<R, PanicError>
55+
where
56+
F: FnOnce() -> R,
57+
{
58+
let result = catch_unwind(AssertUnwindSafe(f));
59+
match result {
60+
Ok(x) => Ok(x),
61+
Err(e) => match e.downcast::<String>() {
62+
Ok(s) => Err(PanicError::new(*s)),
63+
Err(e) => match e.downcast::<&str>() {
64+
Ok(m1) => Err(PanicError::new(m1.to_string())),
65+
Err(_) => Err(PanicError::new("unknown cause".to_string())),
66+
},
67+
},
68+
}
69+
}
70+
71+
pub async fn async_try_with_catch_unwind<F, R>(future: F) -> Result<R, PanicError>
72+
where
73+
F: Future<Output = R>,
74+
{
75+
let result = AssertUnwindSafe(future).catch_unwind().await;
76+
match result {
77+
Ok(x) => Ok(x),
78+
Err(e) => match e.downcast::<String>() {
79+
Ok(s) => Err(PanicError::new(*s)),
80+
Err(e) => match e.downcast::<&str>() {
81+
Ok(m1) => Err(PanicError::new(m1.to_string())),
82+
Err(_) => Err(PanicError::new("unknown cause".to_string())),
83+
},
84+
},
85+
}
86+
}
87+
88+
#[cfg(test)]
89+
mod tests {
90+
use super::*;
91+
use std::panic;
92+
93+
#[test]
94+
fn test_try_with_catch_unwind() {
95+
assert_eq!(
96+
try_with_catch_unwind(|| "ok".to_string()),
97+
Ok("ok".to_string())
98+
);
99+
assert_eq!(
100+
try_with_catch_unwind(|| panic!("oops")),
101+
Err(PanicError::new("oops".to_string()))
102+
);
103+
assert_eq!(
104+
try_with_catch_unwind(|| panic!("oops{}", "ie")),
105+
Err(PanicError::new("oopsie".to_string()))
106+
);
107+
}
108+
109+
#[tokio::test]
110+
async fn test_async_try_with_catch_unwind() {
111+
assert_eq!(
112+
async_try_with_catch_unwind(async { "ok".to_string() }).await,
113+
Ok("ok".to_string())
114+
);
115+
assert_eq!(
116+
async_try_with_catch_unwind(async { panic!("oops") }).await,
117+
Err(PanicError::new("oops".to_string()))
118+
);
119+
assert_eq!(
120+
async_try_with_catch_unwind(async { panic!("oops{}", "ie") }).await,
121+
Err(PanicError::new("oopsie".to_string()))
122+
);
123+
}
124+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
pub mod catch_unwind;
19+
20+
mod spawn;
21+
pub use spawn::*;
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::future::Future;
19+
use crate::cube_ext::catch_unwind::{
20+
async_try_with_catch_unwind, try_with_catch_unwind, PanicError,
21+
};
22+
use futures::sink::SinkExt;
23+
use tokio::task::JoinHandle;
24+
use tracing_futures::Instrument;
25+
26+
/// Calls [tokio::spawn] and additionally enables tracing of the spawned task as part of the current
27+
/// computation. This is CubeStore approach to tracing, so all code must use this function instead
28+
/// of replace [tokio::spawn].
29+
pub fn spawn<T>(task: T) -> JoinHandle<T::Output>
30+
where
31+
T: Future + Send + 'static,
32+
T::Output: Send + 'static,
33+
{
34+
if let Some(s) = new_subtask_span() {
35+
tokio::spawn(async move {
36+
let _p = s.parent; // ensure parent stays alive.
37+
task.instrument(s.child).await
38+
})
39+
} else {
40+
tokio::spawn(task)
41+
}
42+
}
43+
44+
/// Propagates current span to blocking operation. See [spawn] for details.
45+
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
46+
where
47+
F: FnOnce() -> R + Send + 'static,
48+
R: Send + 'static,
49+
{
50+
if let Some(s) = new_subtask_span() {
51+
tokio::task::spawn_blocking(move || {
52+
let _p = s.parent; // ensure parent stays alive.
53+
s.child.in_scope(f)
54+
})
55+
} else {
56+
tokio::task::spawn_blocking(f)
57+
}
58+
}
59+
60+
struct SpawnSpans {
61+
parent: tracing::Span,
62+
child: tracing::Span,
63+
}
64+
65+
fn new_subtask_span() -> Option<SpawnSpans> {
66+
let parent = tracing::Span::current();
67+
if parent.is_disabled() {
68+
return None;
69+
}
70+
// TODO: ensure this is always enabled.
71+
let child = tracing::info_span!(parent: &parent, "subtask");
72+
Some(SpawnSpans { parent, child })
73+
}
74+
75+
/// Executes future [f] in a new tokio thread. Catches panics.
76+
pub fn spawn_with_catch_unwind<F, T, E>(f: F) -> JoinHandle<Result<T, E>>
77+
where
78+
F: Future<Output = Result<T, E>> + Send + 'static,
79+
T: Send + 'static,
80+
E: From<PanicError> + Send + 'static,
81+
{
82+
let task = async move {
83+
match async_try_with_catch_unwind(f).await {
84+
Ok(result) => result,
85+
Err(panic) => Err(E::from(panic)),
86+
}
87+
};
88+
spawn(task)
89+
}
90+
91+
/// Executes future [f] in a new tokio thread. Feeds the result into [tx] oneshot channel. Catches panics.
92+
pub fn spawn_oneshot_with_catch_unwind<F, T, E>(
93+
f: F,
94+
tx: futures::channel::oneshot::Sender<Result<T, E>>,
95+
) -> JoinHandle<Result<(), Result<T, E>>>
96+
where
97+
F: Future<Output = Result<T, E>> + Send + 'static,
98+
T: Send + 'static,
99+
E: From<PanicError> + Send + 'static,
100+
{
101+
let task = async move {
102+
match async_try_with_catch_unwind(f).await {
103+
Ok(result) => tx.send(result),
104+
Err(panic) => tx.send(Err(E::from(panic))),
105+
}
106+
};
107+
spawn(task)
108+
}
109+
110+
/// Executes future [f] in a new tokio thread. Catches panics and feeds them into a [tx] mpsc channel
111+
pub fn spawn_mpsc_with_catch_unwind<F, T, E>(
112+
f: F,
113+
mut tx: futures::channel::mpsc::Sender<Result<T, E>>,
114+
) -> JoinHandle<()>
115+
where
116+
F: Future<Output = ()> + Send + 'static,
117+
T: Send + 'static,
118+
E: From<PanicError> + Send + 'static,
119+
{
120+
let task = async move {
121+
match async_try_with_catch_unwind(f).await {
122+
Ok(_) => (),
123+
Err(panic) => {
124+
tx.send(Err(E::from(panic))).await.ok();
125+
}
126+
}
127+
};
128+
spawn(task)
129+
}
130+
131+
/// Executes fn [f] in a new tokio thread. Catches panics and feeds them into a [tx] mpsc channel.
132+
pub fn spawn_blocking_mpsc_with_catch_unwind<F, R, T, E>(
133+
f: F,
134+
tx: tokio::sync::mpsc::Sender<Result<T, E>>,
135+
) -> JoinHandle<()>
136+
where
137+
F: FnOnce() -> R + Send + 'static,
138+
R: Send + 'static,
139+
T: Send + 'static,
140+
E: From<PanicError> + Send + 'static,
141+
{
142+
let task = move || match try_with_catch_unwind(f) {
143+
Ok(_) => (),
144+
Err(panic) => {
145+
tx.blocking_send(Err(E::from(panic))).ok();
146+
}
147+
};
148+
spawn_blocking(task)
149+
}

datafusion/physical-plan/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ pub mod unnest;
8080
pub mod values;
8181
pub mod windows;
8282
pub mod work_table;
83+
pub mod cube_ext;
8384

8485
pub mod udaf {
8586
pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr;

0 commit comments

Comments
 (0)