diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 38e2500..79a47f7 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -1,4 +1,5 @@ -use std::sync::Arc; +use core::fmt; +use std::{fmt::Display, sync::Arc}; use arrow_schema::DataType; @@ -24,6 +25,12 @@ pub enum JoinType { RightAnti, } +impl Display for JoinType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + /// TODO: documentation #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PhysicalNodeType { @@ -49,8 +56,7 @@ impl std::fmt::Display for PhysicalNodeType { pub enum PredicateType { List, Constant(ConstantType), - AttributeRef, - ExternAttributeRef, + AttrIndex, UnOp(UnOpType), BinOp(BinOpType), LogOp(LogOpType), @@ -77,7 +83,7 @@ pub struct PredicateNode { /// A generic predicate node type pub typ: PredicateType, /// Child predicate nodes, always materialized - pub children: Vec, + pub children: Vec, /// Data associated with the predicate, if any pub data: Option, } @@ -94,3 +100,28 @@ impl std::fmt::Display for PredicateNode { write!(f, ")") } } + +impl PredicateNode { + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.children[idx].clone() + } + + pub fn unwrap_data(&self) -> Value { + self.data.clone().unwrap() + } +} +pub trait ReprPredicateNode: 'static + Clone { + fn into_pred_node(self) -> ArcPredicateNode; + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option; +} + +impl ReprPredicateNode for ArcPredicateNode { + fn into_pred_node(self) -> ArcPredicateNode { + self + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + Some(pred_node) + } +} diff --git a/optd-cost-model/src/common/predicates/attr_index_pred.rs b/optd-cost-model/src/common/predicates/attr_index_pred.rs new file mode 100644 index 0000000..412c7a3 --- /dev/null +++ b/optd-cost-model/src/common/predicates/attr_index_pred.rs @@ -0,0 +1,42 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +/// [`AttributeIndexPred`] represents the position of an attribute in a schema or +/// [`GroupAttrRefs`]. +/// +/// The `data` field holds the index of the attribute in the schema or [`GroupAttrRefs`]. +#[derive(Clone, Debug)] +pub struct AttrIndexPred(pub ArcPredicateNode); + +impl AttrIndexPred { + pub fn new(attr_idx: u64) -> AttrIndexPred { + AttrIndexPred( + PredicateNode { + typ: PredicateType::AttrIndex, + children: vec![], + data: Some(Value::UInt64(attr_idx)), + } + .into(), + ) + } + + /// Gets the attribute index. + pub fn attr_index(&self) -> u64 { + self.0.data.as_ref().unwrap().as_u64() + } +} + +impl ReprPredicateNode for AttrIndexPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::AttrIndex { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/bin_op_pred.rs b/optd-cost-model/src/common/predicates/bin_op_pred.rs index 196d987..5c48688 100644 --- a/optd-cost-model/src/common/predicates/bin_op_pred.rs +++ b/optd-cost-model/src/common/predicates/bin_op_pred.rs @@ -1,3 +1,5 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum BinOpType { @@ -38,3 +40,48 @@ impl BinOpType { ) } } + +#[derive(Clone, Debug)] +pub struct BinOpPred(pub ArcPredicateNode); + +impl BinOpPred { + pub fn new(left: ArcPredicateNode, right: ArcPredicateNode, op_type: BinOpType) -> Self { + BinOpPred( + PredicateNode { + typ: PredicateType::BinOp(op_type), + children: vec![left, right], + data: None, + } + .into(), + ) + } + + pub fn left_child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn right_child(&self) -> ArcPredicateNode { + self.0.child(1) + } + + pub fn op_type(&self) -> BinOpType { + if let PredicateType::BinOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a bin op") + } + } +} + +impl ReprPredicateNode for BinOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::BinOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/cast_pred.rs b/optd-cost-model/src/common/predicates/cast_pred.rs new file mode 100644 index 0000000..2e1ef54 --- /dev/null +++ b/optd-cost-model/src/common/predicates/cast_pred.rs @@ -0,0 +1,49 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::data_type_pred::DataTypePred; + +/// [`CastPred`] casts a column from one data type to another. +/// +/// A [`CastPred`] has two children: +/// 1. The original data to cast +/// 2. The target data type to cast to +#[derive(Clone, Debug)] +pub struct CastPred(pub ArcPredicateNode); + +impl CastPred { + pub fn new(child: ArcPredicateNode, cast_to: DataType) -> Self { + CastPred( + PredicateNode { + typ: PredicateType::Cast, + children: vec![child, DataTypePred::new(cast_to).into_pred_node()], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn cast_to(&self) -> DataType { + DataTypePred::from_pred_node(self.0.child(1)) + .unwrap() + .data_type() + } +} + +impl ReprPredicateNode for CastPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Cast) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/constant_pred.rs b/optd-cost-model/src/common/predicates/constant_pred.rs index 7923ae4..61285f7 100644 --- a/optd-cost-model/src/common/predicates/constant_pred.rs +++ b/optd-cost-model/src/common/predicates/constant_pred.rs @@ -1,5 +1,14 @@ +use std::sync::Arc; + +use arrow_schema::{DataType, IntervalUnit}; +use optd_persistent::cost_model::interface::AttrType; use serde::{Deserialize, Serialize}; +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::{SerializableOrderedF64, Value}, +}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub enum ConstantType { @@ -19,3 +28,193 @@ pub enum ConstantType { Decimal, Binary, } + +impl ConstantType { + pub fn get_data_type_from_value(value: &Value) -> Self { + match value { + Value::Bool(_) => ConstantType::Bool, + Value::String(_) => ConstantType::Utf8String, + Value::UInt8(_) => ConstantType::UInt8, + Value::UInt16(_) => ConstantType::UInt16, + Value::UInt32(_) => ConstantType::UInt32, + Value::UInt64(_) => ConstantType::UInt64, + Value::Int8(_) => ConstantType::Int8, + Value::Int16(_) => ConstantType::Int16, + Value::Int32(_) => ConstantType::Int32, + Value::Int64(_) => ConstantType::Int64, + Value::Float(_) => ConstantType::Float64, + Value::Date32(_) => ConstantType::Date, + _ => unimplemented!("get_data_type_from_value() not implemented for value {value}"), + } + } + + // TODO: current DataType and ConstantType are not 1 to 1 mapping + // optd schema stores constantType from data type in catalog.get + // for decimal128, the precision is lost + pub fn from_data_type(data_type: DataType) -> Self { + match data_type { + DataType::Binary => ConstantType::Binary, + DataType::Boolean => ConstantType::Bool, + DataType::UInt8 => ConstantType::UInt8, + DataType::UInt16 => ConstantType::UInt16, + DataType::UInt32 => ConstantType::UInt32, + DataType::UInt64 => ConstantType::UInt64, + DataType::Int8 => ConstantType::Int8, + DataType::Int16 => ConstantType::Int16, + DataType::Int32 => ConstantType::Int32, + DataType::Int64 => ConstantType::Int64, + DataType::Float64 => ConstantType::Float64, + DataType::Date32 => ConstantType::Date, + DataType::Interval(IntervalUnit::MonthDayNano) => ConstantType::IntervalMonthDateNano, + DataType::Utf8 => ConstantType::Utf8String, + DataType::Decimal128(_, _) => ConstantType::Decimal, + _ => unimplemented!("no conversion to ConstantType for DataType {data_type}"), + } + } + + pub fn into_data_type(&self) -> DataType { + match self { + ConstantType::Binary => DataType::Binary, + ConstantType::Bool => DataType::Boolean, + ConstantType::UInt8 => DataType::UInt8, + ConstantType::UInt16 => DataType::UInt16, + ConstantType::UInt32 => DataType::UInt32, + ConstantType::UInt64 => DataType::UInt64, + ConstantType::Int8 => DataType::Int8, + ConstantType::Int16 => DataType::Int16, + ConstantType::Int32 => DataType::Int32, + ConstantType::Int64 => DataType::Int64, + ConstantType::Float64 => DataType::Float64, + ConstantType::Date => DataType::Date32, + ConstantType::IntervalMonthDateNano => DataType::Interval(IntervalUnit::MonthDayNano), + ConstantType::Decimal => DataType::Float64, + ConstantType::Utf8String => DataType::Utf8, + } + } + + pub fn from_persistent_attr_type(attr_type: AttrType) -> Self { + match attr_type { + AttrType::Integer => ConstantType::Int32, + AttrType::Float => ConstantType::Float64, + AttrType::Varchar => ConstantType::Utf8String, + AttrType::Boolean => ConstantType::Bool, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConstantPred(pub ArcPredicateNode); + +impl ConstantPred { + pub fn new(value: Value) -> Self { + let typ = ConstantType::get_data_type_from_value(&value); + Self::new_with_type(value, typ) + } + + pub fn new_with_type(value: Value, typ: ConstantType) -> Self { + ConstantPred( + PredicateNode { + typ: PredicateType::Constant(typ), + children: vec![], + data: Some(value), + } + .into(), + ) + } + + pub fn bool(value: bool) -> Self { + Self::new_with_type(Value::Bool(value), ConstantType::Bool) + } + + pub fn string(value: impl AsRef) -> Self { + Self::new_with_type( + Value::String(value.as_ref().into()), + ConstantType::Utf8String, + ) + } + + pub fn uint8(value: u8) -> Self { + Self::new_with_type(Value::UInt8(value), ConstantType::UInt8) + } + + pub fn uint16(value: u16) -> Self { + Self::new_with_type(Value::UInt16(value), ConstantType::UInt16) + } + + pub fn uint32(value: u32) -> Self { + Self::new_with_type(Value::UInt32(value), ConstantType::UInt32) + } + + pub fn uint64(value: u64) -> Self { + Self::new_with_type(Value::UInt64(value), ConstantType::UInt64) + } + + pub fn int8(value: i8) -> Self { + Self::new_with_type(Value::Int8(value), ConstantType::Int8) + } + + pub fn int16(value: i16) -> Self { + Self::new_with_type(Value::Int16(value), ConstantType::Int16) + } + + pub fn int32(value: i32) -> Self { + Self::new_with_type(Value::Int32(value), ConstantType::Int32) + } + + pub fn int64(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Int64) + } + + pub fn interval_month_day_nano(value: i128) -> Self { + Self::new_with_type(Value::Int128(value), ConstantType::IntervalMonthDateNano) + } + + pub fn float64(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Float64, + ) + } + + pub fn date(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Date) + } + + pub fn decimal(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Decimal, + ) + } + + pub fn serialized(value: Arc<[u8]>) -> Self { + Self::new_with_type(Value::Serialized(value), ConstantType::Binary) + } + + /// Gets the constant value. + pub fn value(&self) -> Value { + self.0.data.clone().unwrap() + } + + pub fn constant_type(&self) -> ConstantType { + if let PredicateType::Constant(typ) = self.0.typ { + typ + } else { + panic!("not a constant") + } + } +} + +impl ReprPredicateNode for ConstantPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(rel_node: ArcPredicateNode) -> Option { + if let PredicateType::Constant(_) = rel_node.typ { + Some(Self(rel_node)) + } else { + None + } + } +} diff --git a/optd-cost-model/src/common/predicates/data_type_pred.rs b/optd-cost-model/src/common/predicates/data_type_pred.rs new file mode 100644 index 0000000..fe29336 --- /dev/null +++ b/optd-cost-model/src/common/predicates/data_type_pred.rs @@ -0,0 +1,40 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct DataTypePred(pub ArcPredicateNode); + +impl DataTypePred { + pub fn new(typ: DataType) -> Self { + DataTypePred( + PredicateNode { + typ: PredicateType::DataType(typ), + children: vec![], + data: None, + } + .into(), + ) + } + + pub fn data_type(&self) -> DataType { + if let PredicateType::DataType(ref data_type) = self.0.typ { + data_type.clone() + } else { + panic!("not a data type") + } + } +} + +impl ReprPredicateNode for DataTypePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::DataType(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/in_list_pred.rs b/optd-cost-model/src/common/predicates/in_list_pred.rs new file mode 100644 index 0000000..8d3b511 --- /dev/null +++ b/optd-cost-model/src/common/predicates/in_list_pred.rs @@ -0,0 +1,48 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +use super::list_pred::ListPred; + +#[derive(Clone, Debug)] +pub struct InListPred(pub ArcPredicateNode); + +impl InListPred { + pub fn new(child: ArcPredicateNode, list: ListPred, negated: bool) -> Self { + InListPred( + PredicateNode { + typ: PredicateType::InList, + children: vec![child, list.into_pred_node()], + data: Some(Value::Bool(negated)), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn list(&self) -> ListPred { + ListPred::from_pred_node(self.0.child(1)).unwrap() + } + + /// `true` for `NOT IN`. + pub fn negated(&self) -> bool { + self.0.data.as_ref().unwrap().as_bool() + } +} + +impl ReprPredicateNode for InListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::InList) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/like_pred.rs b/optd-cost-model/src/common/predicates/like_pred.rs new file mode 100644 index 0000000..bf9fe31 --- /dev/null +++ b/optd-cost-model/src/common/predicates/like_pred.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +#[derive(Clone, Debug)] +pub struct LikePred(pub ArcPredicateNode); + +impl LikePred { + pub fn new( + negated: bool, + case_insensitive: bool, + child: ArcPredicateNode, + pattern: ArcPredicateNode, + ) -> Self { + // TODO: support multiple values in data. + let negated = if negated { 1 } else { 0 }; + let case_insensitive = if case_insensitive { 1 } else { 0 }; + LikePred( + PredicateNode { + typ: PredicateType::Like, + children: vec![child.into_pred_node(), pattern.into_pred_node()], + data: Some(Value::Serialized(Arc::new([negated, case_insensitive]))), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn pattern(&self) -> ArcPredicateNode { + self.0.child(1) + } + + /// `true` for `NOT LIKE`. + pub fn negated(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[0] != 0, + _ => panic!("not a serialized value"), + } + } + + pub fn case_insensitive(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[1] != 0, + _ => panic!("not a serialized value"), + } + } +} + +impl ReprPredicateNode for LikePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Like) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/list_pred.rs b/optd-cost-model/src/common/predicates/list_pred.rs new file mode 100644 index 0000000..972598d --- /dev/null +++ b/optd-cost-model/src/common/predicates/list_pred.rs @@ -0,0 +1,47 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct ListPred(pub ArcPredicateNode); + +impl ListPred { + pub fn new(preds: Vec) -> Self { + ListPred( + PredicateNode { + typ: PredicateType::List, + children: preds, + data: None, + } + .into(), + ) + } + + /// Gets number of expressions in the list + pub fn len(&self) -> usize { + self.0.children.len() + } + + pub fn is_empty(&self) -> bool { + self.0.children.is_empty() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn to_vec(&self) -> Vec { + self.0.children.clone() + } +} + +impl ReprPredicateNode for ListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::List { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/log_op_pred.rs b/optd-cost-model/src/common/predicates/log_op_pred.rs index 88c5746..1899cb1 100644 --- a/optd-cost-model/src/common/predicates/log_op_pred.rs +++ b/optd-cost-model/src/common/predicates/log_op_pred.rs @@ -1,5 +1,9 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::list_pred::ListPred; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum LogOpType { @@ -12,3 +16,70 @@ impl Display for LogOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct LogOpPred(pub ArcPredicateNode); + +impl LogOpPred { + pub fn new(op_type: LogOpType, preds: Vec) -> Self { + LogOpPred( + PredicateNode { + typ: PredicateType::LogOp(op_type), + children: preds, + data: None, + } + .into(), + ) + } + + /// flatten_nested_logical is a helper function to flatten nested logical operators with same op + /// type eg. (a AND (b AND c)) => ExprList([a, b, c]) + /// (a OR (b OR c)) => ExprList([a, b, c]) + /// It assume the children of the input expr_list are already flattened + /// and can only be used in bottom up manner + pub fn new_flattened_nested_logical(op: LogOpType, expr_list: ListPred) -> Self { + // Since we assume that we are building the children bottom up, + // there is no need to call flatten_nested_logical recursively + let mut new_expr_list = Vec::new(); + for child in expr_list.to_vec() { + if let PredicateType::LogOp(child_op) = child.typ { + if child_op == op { + let child_log_op_expr = LogOpPred::from_pred_node(child).unwrap(); + new_expr_list.extend(child_log_op_expr.children().to_vec()); + continue; + } + } + new_expr_list.push(child.clone()); + } + LogOpPred::new(op, new_expr_list) + } + + pub fn children(&self) -> Vec { + self.0.children.clone() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn op_type(&self) -> LogOpType { + if let PredicateType::LogOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a log op") + } + } +} + +impl ReprPredicateNode for LogOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::LogOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 87e6e94..40c64cf 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -1,6 +1,12 @@ +pub mod attr_index_pred; pub mod bin_op_pred; +pub mod cast_pred; pub mod constant_pred; +pub mod data_type_pred; pub mod func_pred; +pub mod in_list_pred; +pub mod like_pred; +pub mod list_pred; pub mod log_op_pred; pub mod sort_order_pred; pub mod un_op_pred; diff --git a/optd-cost-model/src/common/predicates/un_op_pred.rs b/optd-cost-model/src/common/predicates/un_op_pred.rs index d33158f..a3fc270 100644 --- a/optd-cost-model/src/common/predicates/un_op_pred.rs +++ b/optd-cost-model/src/common/predicates/un_op_pred.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum UnOpType { @@ -12,3 +14,44 @@ impl Display for UnOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct UnOpPred(pub ArcPredicateNode); + +impl UnOpPred { + pub fn new(child: ArcPredicateNode, op_type: UnOpType) -> Self { + UnOpPred( + PredicateNode { + typ: PredicateType::UnOp(op_type), + children: vec![child], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn op_type(&self) -> UnOpType { + if let PredicateType::UnOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a un op") + } + } +} + +impl ReprPredicateNode for UnOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::UnOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/properties/attr_ref.rs b/optd-cost-model/src/common/properties/attr_ref.rs index eb10fbb..d6105b6 100644 --- a/optd-cost-model/src/common/properties/attr_ref.rs +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -23,6 +23,10 @@ pub enum AttrRef { } impl AttrRef { + pub fn new_base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { + AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) + } + pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) } @@ -161,9 +165,9 @@ impl SemanticCorrelation { } /// [`GroupAttrRefs`] represents the attributes of a group in a query. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct GroupAttrRefs { - attribute_refs: AttrRefs, + attr_refs: AttrRefs, /// Correlation of the output attributes of the group. output_correlation: Option, } @@ -171,13 +175,13 @@ pub struct GroupAttrRefs { impl GroupAttrRefs { pub fn new(attribute_refs: AttrRefs, output_correlation: Option) -> Self { Self { - attribute_refs, + attr_refs: attribute_refs, output_correlation, } } - pub fn base_table_attribute_refs(&self) -> &AttrRefs { - &self.attribute_refs + pub fn attr_refs(&self) -> &AttrRefs { + &self.attr_refs } pub fn output_correlation(&self) -> Option<&SemanticCorrelation> { diff --git a/optd-cost-model/src/common/properties/mod.rs b/optd-cost-model/src/common/properties/mod.rs index c9acbd1..a90d634 100644 --- a/optd-cost-model/src/common/properties/mod.rs +++ b/optd-cost-model/src/common/properties/mod.rs @@ -21,3 +21,21 @@ impl std::fmt::Display for Attribute { } } } + +impl Attribute { + pub fn new(name: String, typ: ConstantType, nullable: bool) -> Self { + Self { + name, + typ, + nullable, + } + } + + pub fn new_non_null_int64(name: String) -> Self { + Self { + name, + typ: ConstantType::Int64, + nullable: false, + } + } +} diff --git a/optd-cost-model/src/common/properties/schema.rs b/optd-cost-model/src/common/properties/schema.rs index 4ee4fce..d25a23a 100644 --- a/optd-cost-model/src/common/properties/schema.rs +++ b/optd-cost-model/src/common/properties/schema.rs @@ -33,3 +33,9 @@ impl Schema { self.len() == 0 } } + +impl From> for Schema { + fn from(attributes: Vec) -> Self { + Self::new(attributes) + } +} diff --git a/optd-cost-model/src/common/types.rs b/optd-cost-model/src/common/types.rs index 1e92355..fecd143 100644 --- a/optd-cost-model/src/common/types.rs +++ b/optd-cost-model/src/common/types.rs @@ -1,24 +1,27 @@ use std::fmt::Display; +/// TODO: Implement from and to methods for the following types to enable conversion +/// to and from their persistent counterparts. + /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct GroupId(pub usize); +pub struct GroupId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct ExprId(pub usize); +pub struct ExprId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct TableId(pub usize); +pub struct TableId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct AttrId(pub usize); +pub struct AttrId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct EpochId(pub usize); +pub struct EpochId(pub u64); impl Display for GroupId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {