@@ -10,7 +10,7 @@ use super::PersistentMemo;
1010use crate :: {
1111 entities:: * ,
1212 expression:: { LogicalExpression , PhysicalExpression } ,
13- memo:: { GroupId , LogicalExpressionId , MemoError , PhysicalExpressionId } ,
13+ memo:: { GroupId , GroupStatus , LogicalExpressionId , MemoError , PhysicalExpressionId } ,
1414 OptimizerResult , DATABASE_URL ,
1515} ;
1616use sea_orm:: {
@@ -66,6 +66,40 @@ impl PersistentMemo {
6666 . ok_or ( MemoError :: UnknownGroup ( group_id) ) ?)
6767 }
6868
69+ /// Retrieves the root / canonical group ID of the given group ID.
70+ ///
71+ /// The groups form a union find / disjoint set parent pointer forest, where group merging
72+ /// causes two trees to merge.
73+ ///
74+ /// This function uses the path compression optimization, which amortizes the cost to a single
75+ /// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
76+ pub async fn get_root_group ( & self , group_id : GroupId ) -> OptimizerResult < GroupId > {
77+ let mut curr_group = self . get_group ( group_id) . await ?;
78+
79+ // Traverse up the path and find the root group, keeping track of groups we have visited.
80+ let mut path = vec ! [ ] ;
81+ loop {
82+ let Some ( parent_id) = curr_group. parent_id else {
83+ break ;
84+ } ;
85+
86+ let next_group = self . get_group ( GroupId ( parent_id) ) . await ?;
87+ path. push ( curr_group) ;
88+ curr_group = next_group;
89+ }
90+
91+ let root_id = GroupId ( curr_group. id ) ;
92+
93+ // Path Compression Optimization:
94+ // For every group along the path that we walked, set their parent id pointer to the root.
95+ // This allows for an amortized O(1) cost for `get_root_group`.
96+ for group in path {
97+ self . update_group_parent ( GroupId ( group. id ) , root_id) . await ?;
98+ }
99+
100+ Ok ( root_id)
101+ }
102+
69103 /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
70104 ///
71105 /// If the physical expression does not exist, returns a
@@ -146,6 +180,32 @@ impl PersistentMemo {
146180 Ok ( children)
147181 }
148182
183+ /// Updates / replaces a group's status. Returns the previous group status.
184+ ///
185+ /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
186+ pub async fn update_group_status (
187+ & self ,
188+ group_id : GroupId ,
189+ status : GroupStatus ,
190+ ) -> OptimizerResult < GroupStatus > {
191+ // First retrieve the group record.
192+ let mut group = self . get_group ( group_id) . await ?. into_active_model ( ) ;
193+
194+ // Update the group's status.
195+ let old_status = group. status ;
196+ group. status = Set ( status as u8 as i8 ) ;
197+ group. update ( & self . db ) . await ?;
198+
199+ let old_status = match old_status. unwrap ( ) {
200+ 0 => GroupStatus :: InProgress ,
201+ 1 => GroupStatus :: Explored ,
202+ 2 => GroupStatus :: Optimized ,
203+ _ => panic ! ( "encountered an invalid group status" ) ,
204+ } ;
205+
206+ Ok ( old_status)
207+ }
208+
149209 /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
150210 /// winner's physical expression ID.
151211 ///
@@ -167,8 +227,45 @@ impl PersistentMemo {
167227 group. update ( & self . db ) . await ?;
168228
169229 // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
170- let old = old_id. unwrap ( ) . map ( PhysicalExpressionId ) ;
171- Ok ( old)
230+ let old_id = old_id. unwrap ( ) . map ( PhysicalExpressionId ) ;
231+ Ok ( old_id)
232+ }
233+
234+ /// Updates / replaces a group's parent group. Optionally returns the previous parent.
235+ ///
236+ /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
237+ pub async fn update_group_parent (
238+ & self ,
239+ group_id : GroupId ,
240+ parent_id : GroupId ,
241+ ) -> OptimizerResult < Option < GroupId > > {
242+ // First retrieve the group record.
243+ let mut group = self . get_group ( group_id) . await ?. into_active_model ( ) ;
244+
245+ // Check that the parent group exists.
246+ let _ = self . get_group ( parent_id) . await ?;
247+
248+ // Update the group to point to the new parent.
249+ let old_parent = group. parent_id ;
250+ group. parent_id = Set ( Some ( parent_id. 0 ) ) ;
251+ group. update ( & self . db ) . await ?;
252+
253+ // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
254+ let old_parent = old_parent. unwrap ( ) . map ( GroupId ) ;
255+ Ok ( old_parent)
256+ }
257+
258+ /// Merges two groups sets together. Returns the new root group of the unioned sets.
259+ ///
260+ /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
261+ ///
262+ /// TODO use union by rank / size as an optimization?
263+ pub async fn merge_groups ( & self , group1 : GroupId , group2 : GroupId ) -> OptimizerResult < GroupId > {
264+ // Without tracking the size of each of the groups, it is arbitrary which group is better to
265+ // merge into the other. So we will arbitrarily choose `group1` to merge into `group2`.
266+ self . update_group_parent ( group1, group2) . await ?;
267+
268+ Ok ( group2)
172269 }
173270
174271 /// Adds a logical expression to an existing group via its ID.
@@ -195,7 +292,7 @@ impl PersistentMemo {
195292 ) -> OptimizerResult < Result < LogicalExpressionId , LogicalExpressionId > > {
196293 // Check if the expression already exists anywhere in the memo table.
197294 if let Some ( existing_id) = self
198- . is_duplicate_logical_expression ( & logical_expression)
295+ . is_duplicate_logical_expression ( & logical_expression, children )
199296 . await ?
200297 {
201298 return Ok ( Err ( existing_id) ) ;
@@ -227,7 +324,15 @@ impl PersistentMemo {
227324 // Finally, insert the fingerprint of the logical expression as well.
228325 let new_expr: LogicalExpression = new_model. into ( ) ;
229326 let kind = new_expr. kind ( ) ;
230- let hash = new_expr. fingerprint ( ) ;
327+
328+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
329+ // groups of the children instead of the child ID themselves.
330+ let mut rewrites = vec ! [ ] ;
331+ for & child_id in children {
332+ let root_id = self . get_root_group ( child_id) . await ?;
333+ rewrites. push ( ( child_id, root_id) ) ;
334+ }
335+ let hash = new_expr. fingerprint_with_rewrite ( & rewrites) ;
231336
232337 let fingerprint = fingerprint:: ActiveModel {
233338 id : NotSet ,
@@ -296,13 +401,22 @@ impl PersistentMemo {
296401 pub async fn is_duplicate_logical_expression (
297402 & self ,
298403 logical_expression : & LogicalExpression ,
404+ children : & [ GroupId ] ,
299405 ) -> OptimizerResult < Option < LogicalExpressionId > > {
300406 let model: logical_expression:: Model = logical_expression. clone ( ) . into ( ) ;
301407
302408 // Lookup all expressions that have the same fingerprint and kind. There may be false
303409 // positives, but we will check for those next.
304410 let kind = model. kind ;
305- let fingerprint = logical_expression. fingerprint ( ) ;
411+
412+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
413+ // groups of the children instead of the child ID themselves.
414+ let mut rewrites = vec ! [ ] ;
415+ for & child_id in children {
416+ let root_id = self . get_root_group ( child_id) . await ?;
417+ rewrites. push ( ( child_id, root_id) ) ;
418+ }
419+ let fingerprint = logical_expression. fingerprint_with_rewrite ( & rewrites) ;
306420
307421 // Filter first by the fingerprint, and then the kind.
308422 // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
@@ -360,7 +474,7 @@ impl PersistentMemo {
360474 {
361475 // Check if the expression already exists in the memo table.
362476 if let Some ( existing_id) = self
363- . is_duplicate_logical_expression ( & logical_expression)
477+ . is_duplicate_logical_expression ( & logical_expression, children )
364478 . await ?
365479 {
366480 let ( group_id, _expr) = self . get_logical_expression ( existing_id) . await ?;
@@ -370,7 +484,7 @@ impl PersistentMemo {
370484 // The expression does not exist yet, so we need to create a new group and new expression.
371485 let group = cascades_group:: ActiveModel {
372486 winner : Set ( None ) ,
373- is_optimized : Set ( false ) ,
487+ status : Set ( 0 ) , // `GroupStatus::InProgress` status.
374488 ..Default :: default ( )
375489 } ;
376490
@@ -401,7 +515,15 @@ impl PersistentMemo {
401515 // Finally, insert the fingerprint of the logical expression as well.
402516 let new_expr: LogicalExpression = new_model. into ( ) ;
403517 let kind = new_expr. kind ( ) ;
404- let hash = new_expr. fingerprint ( ) ;
518+
519+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
520+ // groups of the children instead of the child ID themselves.
521+ let mut rewrites = vec ! [ ] ;
522+ for & child_id in children {
523+ let root_id = self . get_root_group ( child_id) . await ?;
524+ rewrites. push ( ( child_id, root_id) ) ;
525+ }
526+ let hash = new_expr. fingerprint_with_rewrite ( & rewrites) ;
405527
406528 let fingerprint = fingerprint:: ActiveModel {
407529 id : NotSet ,
0 commit comments