Skip to content

Commit cc43766

Browse files
peter-toth and alamb authored
Implement Eq, PartialEq, Hash for dyn PhysicalExpr (apache#13005)
* Implement Eq, PartialEq, Hash for PhysicalExpr
* Manually implement PartialEq and Hash for BinaryExpr
* Port more
* Complete manual derivations
* fmt
* add and fix docs

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 7c6f891 commit cc43766

File tree

22 files changed

+231
-368
lines changed

22 files changed

+231
-368
lines changed

datafusion/core/tests/sql/path_partition.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ use bytes::Bytes;
4747
use chrono::{TimeZone, Utc};
4848
use datafusion_expr::{col, lit, Expr, Operator};
4949
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
50-
use datafusion_physical_expr::PhysicalExpr;
5150
use futures::stream::{self, BoxStream};
5251
use object_store::{
5352
path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta,
@@ -97,7 +96,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> {
9796
assert!(pred.as_any().is::<BinaryExpr>());
9897
let pred = pred.as_any().downcast_ref::<BinaryExpr>().unwrap();
9998

100-
assert_eq!(pred, expected.as_any());
99+
assert_eq!(pred, expected.as_ref());
101100

102101
Ok(())
103102
}

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ use datafusion_expr_common::sort_properties::ExprProperties;
5252
/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
5353
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
5454
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
55-
pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
55+
pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
5656
/// Returns the physical expression as [`Any`] so that it can be
5757
/// downcast to a specific implementation.
5858
fn as_any(&self) -> &dyn Any;
@@ -141,38 +141,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
141141
Ok(Some(vec![]))
142142
}
143143

144-
/// Update the hash `state` with this expression requirements from
145-
/// [`Hash`].
146-
///
147-
/// This method is required to support hashing [`PhysicalExpr`]s. To
148-
/// implement it, typically the type implementing
149-
/// [`PhysicalExpr`] implements [`Hash`] and
150-
/// then the following boiler plate is used:
151-
///
152-
/// # Example:
153-
/// ```
154-
/// // User defined expression that derives Hash
155-
/// #[derive(Hash, Debug, PartialEq, Eq)]
156-
/// struct MyExpr {
157-
/// val: u64
158-
/// }
159-
///
160-
/// // impl PhysicalExpr {
161-
/// // ...
162-
/// # impl MyExpr {
163-
/// // Boiler plate to call the derived Hash impl
164-
/// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
165-
/// use std::hash::Hash;
166-
/// let mut s = state;
167-
/// self.hash(&mut s);
168-
/// }
169-
/// // }
170-
/// # }
171-
/// ```
172-
/// Note: [`PhysicalExpr`] is not constrained by [`Hash`]
173-
/// directly because it must remain object safe.
174-
fn dyn_hash(&self, _state: &mut dyn Hasher);
175-
176144
/// Calculates the properties of this [`PhysicalExpr`] based on its
177145
/// children's properties (i.e. order and range), recursively aggregating
178146
/// the information from its children. In cases where the [`PhysicalExpr`]
@@ -183,6 +151,42 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
183151
}
184152
}
185153

154+
/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object
155+
/// safe. To ease implementation, a blanket implementation is provided for [`Eq`] types.
156+
pub trait DynEq {
157+
fn dyn_eq(&self, other: &dyn Any) -> bool;
158+
}
159+
160+
impl<T: Eq + Any> DynEq for T {
161+
fn dyn_eq(&self, other: &dyn Any) -> bool {
162+
other
163+
.downcast_ref::<Self>()
164+
.map_or(false, |other| other == self)
165+
}
166+
}
167+
168+
impl PartialEq for dyn PhysicalExpr {
169+
fn eq(&self, other: &Self) -> bool {
170+
self.dyn_eq(other.as_any())
171+
}
172+
}
173+
174+
impl Eq for dyn PhysicalExpr {}
175+
176+
/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must remain
177+
/// object safe. To ease implementation, a blanket implementation is provided for [`Hash`]
178+
/// types.
179+
pub trait DynHash {
180+
fn dyn_hash(&self, _state: &mut dyn Hasher);
181+
}
182+
183+
impl<T: Hash + Any> DynHash for T {
184+
fn dyn_hash(&self, mut state: &mut dyn Hasher) {
185+
self.type_id().hash(&mut state);
186+
self.hash(&mut state)
187+
}
188+
}
189+
186190
impl Hash for dyn PhysicalExpr {
187191
fn hash<H: Hasher>(&self, state: &mut H) {
188192
self.dyn_hash(state);
@@ -210,20 +214,6 @@ pub fn with_new_children_if_necessary(
210214
}
211215
}
212216

213-
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
214-
if any.is::<Arc<dyn PhysicalExpr>>() {
215-
any.downcast_ref::<Arc<dyn PhysicalExpr>>()
216-
.unwrap()
217-
.as_any()
218-
} else if any.is::<Box<dyn PhysicalExpr>>() {
219-
any.downcast_ref::<Box<dyn PhysicalExpr>>()
220-
.unwrap()
221-
.as_any()
222-
} else {
223-
any
224-
}
225-
}
226-
227217
/// Returns a [`Display`]able list of [`PhysicalExpr`]s
228218
///
229219
/// Example output: `[a + 1, b]`

datafusion/physical-expr-common/src/sort_expr.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,10 @@ use itertools::Itertools;
5858
/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {todo!() }
5959
/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {todo!()}
6060
/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> {todo!()}
61-
/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()}
6261
/// # }
6362
/// # impl Display for MyPhysicalExpr {
6463
/// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") }
6564
/// # }
66-
/// # impl PartialEq<dyn Any> for MyPhysicalExpr {
67-
/// # fn eq(&self, _other: &dyn Any) -> bool { true }
68-
/// # }
6965
/// # fn col(name: &str) -> Arc<dyn PhysicalExpr> { Arc::new(MyPhysicalExpr) }
7066
/// // Sort by a ASC
7167
/// let options = SortOptions::default();

datafusion/physical-expr/src/equivalence/class.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ pub struct ConstExpr {
6666

6767
impl PartialEq for ConstExpr {
6868
fn eq(&self, other: &Self) -> bool {
69-
self.across_partitions == other.across_partitions
70-
&& self.expr.eq(other.expr.as_any())
69+
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
7170
}
7271
}
7372

@@ -120,7 +119,7 @@ impl ConstExpr {
120119

121120
/// Returns true if this constant expression is equal to the given expression
122121
pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
123-
self.expr.eq(other.as_ref().as_any())
122+
self.expr.as_ref() == other.as_ref()
124123
}
125124

126125
/// Returns a [`Display`]able list of `ConstExpr`.
@@ -556,7 +555,7 @@ impl EquivalenceGroup {
556555
new_classes.push((source, vec![Arc::clone(target)]));
557556
}
558557
if let Some((_, values)) =
559-
new_classes.iter_mut().find(|(key, _)| key.eq(source))
558+
new_classes.iter_mut().find(|(key, _)| *key == source)
560559
{
561560
if !physical_exprs_contains(values, target) {
562561
values.push(Arc::clone(target));

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
mod kernels;
1919

20-
use std::hash::{Hash, Hasher};
20+
use std::hash::Hash;
2121
use std::{any::Any, sync::Arc};
2222

2323
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
24-
use crate::physical_expr::down_cast_any_ref;
2524
use crate::PhysicalExpr;
2625

2726
use arrow::array::*;
@@ -48,7 +47,7 @@ use kernels::{
4847
};
4948

5049
/// Binary expression
51-
#[derive(Debug, Hash, Clone)]
50+
#[derive(Debug, Clone, Eq)]
5251
pub struct BinaryExpr {
5352
left: Arc<dyn PhysicalExpr>,
5453
op: Operator,
@@ -57,6 +56,24 @@ pub struct BinaryExpr {
5756
fail_on_overflow: bool,
5857
}
5958

59+
// Manually implement PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
60+
impl PartialEq for BinaryExpr {
61+
fn eq(&self, other: &Self) -> bool {
62+
self.left.eq(&other.left)
63+
&& self.op.eq(&other.op)
64+
&& self.right.eq(&other.right)
65+
&& self.fail_on_overflow.eq(&other.fail_on_overflow)
66+
}
67+
}
68+
impl Hash for BinaryExpr {
69+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
70+
self.left.hash(state);
71+
self.op.hash(state);
72+
self.right.hash(state);
73+
self.fail_on_overflow.hash(state);
74+
}
75+
}
76+
6077
impl BinaryExpr {
6178
/// Create new binary expression
6279
pub fn new(
@@ -477,11 +494,6 @@ impl PhysicalExpr for BinaryExpr {
477494
}
478495
}
479496

480-
fn dyn_hash(&self, state: &mut dyn Hasher) {
481-
let mut s = state;
482-
self.hash(&mut s);
483-
}
484-
485497
/// For each operator, [`BinaryExpr`] has distinct rules.
486498
/// TODO: There may be rules specific to some data types and expression ranges.
487499
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
@@ -525,20 +537,6 @@ impl PhysicalExpr for BinaryExpr {
525537
}
526538
}
527539

528-
impl PartialEq<dyn Any> for BinaryExpr {
529-
fn eq(&self, other: &dyn Any) -> bool {
530-
down_cast_any_ref(other)
531-
.downcast_ref::<Self>()
532-
.map(|x| {
533-
self.left.eq(&x.left)
534-
&& self.op == x.op
535-
&& self.right.eq(&x.right)
536-
&& self.fail_on_overflow.eq(&x.fail_on_overflow)
537-
})
538-
.unwrap_or(false)
539-
}
540-
}
541-
542540
/// Casts dictionary array to result type for binary numerical operators. Such operators
543541
/// between array and scalar produce a dictionary array other than primitive array of the
544542
/// same operators between array and array. This leads to inconsistent result types causing

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
// under the License.
1717

1818
use std::borrow::Cow;
19-
use std::hash::{Hash, Hasher};
19+
use std::hash::Hash;
2020
use std::{any::Any, sync::Arc};
2121

2222
use crate::expressions::try_cast;
23-
use crate::physical_expr::down_cast_any_ref;
2423
use crate::PhysicalExpr;
2524

2625
use arrow::array::*;
@@ -37,7 +36,7 @@ use itertools::Itertools;
3736

3837
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
3938

40-
#[derive(Debug, Hash)]
39+
#[derive(Debug, Hash, PartialEq, Eq)]
4140
enum EvalMethod {
4241
/// CASE WHEN condition THEN result
4342
/// [WHEN ...]
@@ -80,7 +79,7 @@ enum EvalMethod {
8079
/// [WHEN ...]
8180
/// [ELSE result]
8281
/// END
83-
#[derive(Debug, Hash)]
82+
#[derive(Debug, Hash, PartialEq, Eq)]
8483
pub struct CaseExpr {
8584
/// Optional base expression that can be compared to literal values in the "when" expressions
8685
expr: Option<Arc<dyn PhysicalExpr>>,
@@ -506,39 +505,6 @@ impl PhysicalExpr for CaseExpr {
506505
)?))
507506
}
508507
}
509-
510-
fn dyn_hash(&self, state: &mut dyn Hasher) {
511-
let mut s = state;
512-
self.hash(&mut s);
513-
}
514-
}
515-
516-
impl PartialEq<dyn Any> for CaseExpr {
517-
fn eq(&self, other: &dyn Any) -> bool {
518-
down_cast_any_ref(other)
519-
.downcast_ref::<Self>()
520-
.map(|x| {
521-
let expr_eq = match (&self.expr, &x.expr) {
522-
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
523-
(None, None) => true,
524-
_ => false,
525-
};
526-
let else_expr_eq = match (&self.else_expr, &x.else_expr) {
527-
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
528-
(None, None) => true,
529-
_ => false,
530-
};
531-
expr_eq
532-
&& else_expr_eq
533-
&& self.when_then_expr.len() == x.when_then_expr.len()
534-
&& self.when_then_expr.iter().zip(x.when_then_expr.iter()).all(
535-
|((when1, then1), (when2, then2))| {
536-
when1.eq(when2) && then1.eq(then2)
537-
},
538-
)
539-
})
540-
.unwrap_or(false)
541-
}
542508
}
543509

544510
/// Create a CASE expression

0 commit comments

Comments
 (0)