Skip to content

Commit 4f78b8a

Browse files
mwillsey authored and MazterQyou committed
perf(cubesql): Improve rewrite engine performance
1 parent dd1e953 commit 4f78b8a

File tree

4 files changed

+155
-75
lines changed

4 files changed

+155
-75
lines changed

packages/cubejs-backend-native/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/cubesql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ nanoid = "0.3.0"
4848
tokio-util = { version = "0.6.2", features=["compat"] }
4949
comfy-table = "7.1.0"
5050
bitflags = "1.3.2"
51-
egg = { rev = "bdf05cee0a145a524fe8c6c33aa577ac50ace7c9", git = "https://github.com/cube-js/egg.git" }
51+
egg = { rev = "58c2586473360f0821e91ef196b55070ac1afedc", git = "https://github.com/cube-js/egg.git" }
5252
paste = "1.0.6"
5353
csv = "1.1.6"
5454
tracing = { version = "0.1.40", features = ["async-await"] }

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 152 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ use egg::{
2727
};
2828
use itertools::Itertools;
2929
use std::{
30+
borrow::Cow,
3031
fmt::{self, Display, Formatter},
3132
ops::Index,
3233
slice::Iter,
@@ -737,7 +738,31 @@ where
737738
.unwrap()
738739
}
739740

740-
type ListMatches = Vec<Subst>;
741+
/// Match data for one list node, attached to a `Subst` through its `data` field.
///
/// Element substitutions are kept in one flat vector. Each entry links back to
/// the entry it extends through `prevs`, so a complete list match is recovered
/// by following those links from an entry in `start..substs.len()`.
struct ListMatches {
742+
/// Expected number of entries in every chain yielded by `for_each`.
len: usize,
743+
/// Flat storage of element substitutions.
substs: Vec<Subst>,
744+
/// `prevs[i]` is the index of the entry that `substs[i]` extends, or
/// `usize::MAX` when the chain ends at `substs[i]`.
prevs: Vec<usize>,
745+
/// First index of the tail range from which `for_each` starts its chains.
start: usize,
746+
}
747+
impl ListMatches {
748+
/// Returns the index range `start..substs.len()`, i.e. the tail entries
/// from which complete chains are rebuilt.
fn range(&self) -> std::ops::Range<usize> {
749+
self.start..self.substs.len()
750+
}
751+
/// Calls `f` once per chain, passing the chain's substitutions ordered from
/// the chain's terminating entry to its tail entry.
///
/// # Panics
///
/// Panics if a rebuilt chain does not contain exactly `self.len` entries.
fn for_each(&self, mut f: impl FnMut(&[&Subst])) {
752+
// Scratch buffer shared by all chains; cleared at the start of each one.
let mut substs = Vec::with_capacity(self.len);
753+
for i in self.range() {
754+
substs.clear();
755+
let mut i = i;
756+
// Follow the `prevs` links until the `usize::MAX` terminator.
while i != usize::MAX {
757+
substs.push(&self.substs[i]);
758+
i = self.prevs[i];
759+
}
760+
// The walk collected entries tail-first; flip them into forward order.
substs.reverse();
761+
assert_eq!(substs.len(), self.len);
762+
f(&substs);
763+
}
764+
}
765+
}
741766

742767
#[derive(Clone, PartialEq)]
743768
pub enum ListType {
@@ -847,6 +872,68 @@ impl ListNodeSearcher {
847872
}
848873
}
849874
}
875+
876+
/// Expands one match of the list pattern into list-aware substitutions.
///
/// For every node in the e-class bound to `self.list_var` that passes
/// `self.match_node` and has at least one child, `self.elem_pattern` is
/// searched in each child e-class. Matches of consecutive children are linked
/// when they agree on all `self.top_level_elem_vars`. If the final range of
/// the resulting `ListMatches` is non-empty, a clone of `list_subst` carrying
/// it in `data` is pushed to `output`.
///
/// `limit` is passed only to the search of the first child; the remaining
/// children are searched through `search_eclass_with_fn`, which receives no
/// limit here.
// NOTE(review): `'a` appears only on `&'a self` in this signature, so it can
// probably be elided — confirm.
fn search_from_list_matches<'a>(
877+
&'a self,
878+
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
879+
limit: usize,
880+
list_subst: &Subst,
881+
output: &mut Vec<Subst>,
882+
) {
883+
// Id bound to the list variable in this match of the list pattern.
let list_id = list_subst[self.list_var];
884+
for node in egraph[list_id].iter() {
885+
let list_children = node.children();
886+
// Skip nodes rejected by `match_node` and nodes without children.
if !self.match_node(node) || list_children.is_empty() {
887+
continue;
888+
}
889+
890+
let mut list_matches = ListMatches {
891+
len: list_children.len(),
892+
substs: vec![],
893+
prevs: vec![],
894+
start: 0,
895+
};
896+
897+
// Seed with the matches of the first child; no match yields an empty seed.
list_matches.substs = self
898+
.elem_pattern
899+
.search_eclass_with_limit(egraph, list_children[0], limit)
900+
.map_or(vec![], |ms| ms.substs);
901+
902+
// Seed entries extend nothing, so they carry the `usize::MAX` terminator.
list_matches.prevs = vec![usize::MAX; list_matches.substs.len()];
903+
904+
// Two substitutions agree when `get` returns equal results for every
// top-level element variable.
let agree = |subst1: &Subst, subst2: &Subst| {
905+
self.top_level_elem_vars
906+
.iter()
907+
.all(|&v| subst1.get(v) == subst2.get(v))
908+
};
909+
910+
for &list_child in &list_children[1..] {
911+
debug_assert_eq!(list_matches.substs.len(), list_matches.prevs.len());
912+
// Entries recorded for the previous child.
let range = list_matches.range();
913+
// Nothing survived for the previous child, so no chain can be extended.
if range.is_empty() {
914+
break;
915+
}
916+
// Entries pushed from here on belong to the current child.
list_matches.start = list_matches.substs.len();
917+
self.elem_pattern
918+
.search_eclass_with_fn(egraph, list_child, |subst| {
919+
// Add one entry per previous-child entry this match agrees with.
for i in range.clone() {
920+
if agree(&list_matches.substs[i], subst) {
921+
list_matches.substs.push(subst.clone());
922+
list_matches.prevs.push(i);
923+
}
924+
}
925+
Ok(())
926+
})
927+
// The callback only returns `Ok(())`; any `Err` from the search is discarded.
.unwrap_or_default();
928+
}
929+
930+
// Emit a substitution only when at least one complete chain exists.
if !list_matches.range().is_empty() {
931+
let mut subst = list_subst.clone();
932+
subst.data = Some(Arc::new(list_matches));
933+
output.push(subst);
934+
}
935+
}
936+
}
850937
}
851938

852939
impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeSearcher {
@@ -856,55 +943,48 @@ impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeSearcher {
856943
eclass: Id,
857944
limit: usize,
858945
) -> Option<SearchMatches<LogicalPlanLanguage>> {
859-
let mut matches = self
860-
.list_pattern
861-
.search_eclass_with_limit(egraph, eclass, limit)?;
862-
863-
let mut new_substs: Vec<Subst> = vec![];
864-
for subst in matches.substs {
865-
let list_id = subst[self.list_var];
866-
for node in egraph[list_id].iter() {
867-
let list_children = node.children();
868-
if !self.match_node(node) || list_children.is_empty() {
869-
continue;
870-
}
946+
let mut matches = SearchMatches {
947+
substs: vec![],
948+
eclass,
949+
ast: Some(Cow::Borrowed(&self.list_pattern.ast)),
950+
};
951+
self.list_pattern
952+
.search_eclass_with_fn(egraph, eclass, |subst| {
953+
self.search_from_list_matches(egraph, limit, subst, &mut matches.substs);
954+
Ok(())
955+
})
956+
.unwrap_or_default();
871957

872-
let matches_product = list_children
873-
.iter()
874-
.map(|child| {
875-
self.elem_pattern
876-
.search_eclass_with_limit(egraph, *child, limit)
877-
.map_or(vec![], |ms| ms.substs)
878-
})
879-
.multi_cartesian_product();
880-
881-
// TODO(mwillsey) this could be optimized more by filtering the
882-
// matches as you go
883-
for list_matches in matches_product {
884-
let subst0 = &list_matches[0];
885-
let agree_with_top_level = list_matches.iter().all(|m| {
886-
self.top_level_elem_vars
887-
.iter()
888-
.all(|&v| m.get(v) == subst0.get(v))
889-
});
890-
891-
if agree_with_top_level {
892-
let mut subst = subst.clone();
893-
assert_eq!(list_matches.len(), list_children.len());
894-
for &var in &self.top_level_elem_vars {
895-
if let Some(id) = list_matches[0].get(var) {
896-
subst.insert(var, *id);
897-
}
898-
}
899-
subst.data = Some(Arc::new(list_matches));
900-
new_substs.push(subst);
958+
(!matches.substs.is_empty()).then(|| matches)
959+
}
960+
961+
/// Searches via `self.list_pattern.search_with_fn` and groups the results by e-class.
///
/// For each `(id, list_subst)` reported by the search, the substitutions
/// produced by `search_from_list_matches` are appended to the `SearchMatches`
/// entry for `id`. Entries that end up with no substitutions are removed
/// before returning.
fn search_with_limit(
962+
&self,
963+
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
964+
limit: usize,
965+
) -> Vec<SearchMatches<LogicalPlanLanguage>> {
966+
let mut result: Vec<SearchMatches<_>> = vec![];
967+
self.list_pattern
968+
.search_with_fn(egraph, |id, list_subst| {
969+
// Reuse the last entry when it is for the same e-class, otherwise start a new one.
// NOTE(review): only the last entry is checked, so this presumably relies on
// matches for one e-class arriving consecutively — confirm against `search_with_fn`.
let last = match result.last_mut() {
970+
Some(top) if top.eclass == id => top,
971+
_ => {
972+
result.push(SearchMatches {
973+
substs: vec![],
974+
eclass: id,
975+
ast: Some(Cow::Borrowed(&self.list_pattern.ast)),
976+
});
977+
result.last_mut().unwrap()
901978
}
902-
}
903-
}
904-
}
979+
};
980+
debug_assert_eq!(last.eclass, id);
981+
self.search_from_list_matches(egraph, limit, list_subst, &mut last.substs);
982+
Ok(())
983+
})
984+
// The callback only returns `Ok(())`; any `Err` from the search is discarded.
.unwrap_or_default();
905985

906-
matches.substs = new_substs;
907-
(!matches.substs.is_empty()).then(|| matches)
986+
// Drop e-classes whose list expansion produced no substitutions.
result.retain(|matches| !matches.substs.is_empty());
987+
result
908988
}
909989

910990
fn vars(&self) -> Vec<Var> {
@@ -999,42 +1079,42 @@ impl Applier<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeApplier {
9991079
/// Applies the list rewrite for one substitution carrying `ListMatches` data.
///
/// For each chain yielded by `ListMatches::for_each`, every entry of
/// `self.lists` has its `elem_pattern` instantiated once per chain element,
/// the resulting ids are wrapped with `make_node` and bound to `new_list_var`,
/// and `self.list_pattern` is then instantiated and unioned with `eclass`.
/// Returns the ids for which `egraph.union` returned `true`.
///
/// # Panics
///
/// Panics if `subst.data` is `None` or does not hold a `ListMatches`.
fn apply_one(
10001080
&self,
10011081
egraph: &mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
1002-
eclass: Id,
1082+
mut eclass: Id,
10031083
subst: &Subst,
10041084
_searcher_ast: Option<&PatternAst<LogicalPlanLanguage>>,
10051085
_rule_name: Symbol,
10061086
) -> Vec<Id> {
1007-
let mut subst = subst.clone();
1008-
10091087
let data = subst
10101088
.data
10111089
.as_ref()
10121090
.expect("no data, did you use ListNodeSearcher?");
1013-
let list_matches = data
1014-
.downcast_ref::<ListMatches>()
1015-
.expect("wrong data type")
1016-
.clone();
1017-
1018-
for list in &self.lists {
1019-
let new_list = list_matches
1020-
.iter()
1021-
.map(|list_subst| {
1022-
let mut subst = subst.clone();
1023-
subst.extend(list_subst.iter());
1024-
egraph.add_instantiation(&list.elem_pattern, &subst)
1025-
})
1026-
.collect();
1091+
// Borrow the match data in place.
let list_matches = data.downcast_ref::<ListMatches>().expect("wrong data type");
10271092

1028-
subst.insert(list.new_list_var, egraph.add(list.make_node(new_list)));
1029-
}
1093+
// Working copy; the `new_list_var` bindings are inserted into it inside the closure.
let mut subst = subst.clone();
1094+
let mut result_ids = vec![];
1095+
list_matches.for_each(|list_substs| {
1096+
for list in &self.lists {
1097+
let new_list = list_substs
1098+
.iter()
1099+
.map(|list_subst| {
1100+
// Shadowing clone: element bindings are added without touching the outer `subst`.
let mut subst = subst.clone();
1101+
subst.extend(list_subst.iter());
1102+
egraph.add_instantiation(&list.elem_pattern, &subst)
1103+
})
1104+
.collect();
10301105

1031-
let result_id = egraph.add_instantiation(&self.list_pattern, &subst);
1106+
subst.insert(list.new_list_var, egraph.add(list.make_node(new_list)));
1107+
}
1108+
// Add the first element's bindings before instantiating the list pattern.
let mut subst = subst.clone();
1109+
subst.extend(list_substs[0].iter());
1110+
let new_id = egraph.add_instantiation(&self.list_pattern, &subst);
1111+
if egraph.union(eclass, new_id) {
1112+
result_ids.push(new_id);
1113+
// Later unions use the newly added id as their first argument.
eclass = new_id;
1114+
}
1115+
});
10321116

1033-
if egraph.union(eclass, result_id) {
1034-
vec![result_id]
1035-
} else {
1036-
vec![]
1037-
}
1117+
result_ids
10381118
}
10391119

10401120
fn vars(&self) -> Vec<Var> {

0 commit comments

Comments
 (0)