@@ -10,7 +10,7 @@ use super::PersistentMemo;
1010use crate :: {
1111 entities:: * ,
1212 expression:: { LogicalExpression , PhysicalExpression } ,
13- memo:: { GroupId , LogicalExpressionId , MemoError , PhysicalExpressionId } ,
13+ memo:: { GroupId , GroupStatus , LogicalExpressionId , MemoError , PhysicalExpressionId } ,
1414 OptimizerResult , DATABASE_URL ,
1515} ;
1616use sea_orm:: {
@@ -66,6 +66,40 @@ impl PersistentMemo {
6666 . ok_or ( MemoError :: UnknownGroup ( group_id) ) ?)
6767 }
6868
69+ /// Retrieves the root / canonical group ID of the given group ID.
70+ ///
71+ /// The groups form a union find / disjoint set parent pointer forest, where group merging
72+ /// causes two trees to merge.
73+ ///
74+ /// This function uses the path compression optimization, which amortizes the cost to a single
75+ /// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
76+ pub async fn get_root_group ( & self , group_id : GroupId ) -> OptimizerResult < GroupId > {
77+ let mut curr_group = self . get_group ( group_id) . await ?;
78+
79+ // Traverse up the path and find the root group, keeping track of groups we have visited.
80+ let mut path = vec ! [ ] ;
81+ loop {
82+ let Some ( parent_id) = curr_group. parent_id else {
83+ break ;
84+ } ;
85+
86+ let next_group = self . get_group ( GroupId ( parent_id) ) . await ?;
87+ path. push ( curr_group) ;
88+ curr_group = next_group;
89+ }
90+
91+ let root_id = GroupId ( curr_group. id ) ;
92+
93+ // Path Compression Optimization:
94+ // For every group along the path that we walked, set their parent id pointer to the root.
95+ // This allows for an amortized O(1) cost for `get_root_group`.
96+ for group in path {
97+ self . update_group_parent ( GroupId ( group. id ) , root_id) . await ?;
98+ }
99+
100+ Ok ( root_id)
101+ }
102+
69103 /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
70104 ///
71105 /// If the physical expression does not exist, returns a
@@ -146,6 +180,32 @@ impl PersistentMemo {
146180 Ok ( children)
147181 }
148182
183+ /// Updates / replaces a group's status. Returns the previous group status.
184+ ///
185+ /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
186+ pub async fn update_group_status (
187+ & self ,
188+ group_id : GroupId ,
189+ status : GroupStatus ,
190+ ) -> OptimizerResult < GroupStatus > {
191+ // First retrieve the group record.
192+ let mut group = self . get_group ( group_id) . await ?. into_active_model ( ) ;
193+
194+ // Update the group's status.
195+ let old_status = group. status ;
196+ group. status = Set ( status as u8 as i8 ) ;
197+ group. update ( & self . db ) . await ?;
198+
199+ let old_status = match old_status. unwrap ( ) {
200+ 0 => GroupStatus :: InProgress ,
201+ 1 => GroupStatus :: Explored ,
202+ 2 => GroupStatus :: Optimized ,
203+ _ => panic ! ( "encountered an invalid group status" ) ,
204+ } ;
205+
206+ Ok ( old_status)
207+ }
208+
149209 /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
150210 /// winner's physical expression ID.
151211 ///
@@ -167,8 +227,32 @@ impl PersistentMemo {
167227 group. update ( & self . db ) . await ?;
168228
169229 // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
170- let old = old_id. unwrap ( ) . map ( PhysicalExpressionId ) ;
171- Ok ( old)
230+ let old_id = old_id. unwrap ( ) . map ( PhysicalExpressionId ) ;
231+ Ok ( old_id)
232+ }
233+
234+ /// Updates / replaces a group's parent group. Optionally returns the previous parent.
235+ ///
236+ /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
237+ pub async fn update_group_parent (
238+ & self ,
239+ group_id : GroupId ,
240+ parent_id : GroupId ,
241+ ) -> OptimizerResult < Option < GroupId > > {
242+ // First retrieve the group record.
243+ let mut group = self . get_group ( group_id) . await ?. into_active_model ( ) ;
244+
245+ // Check that the parent group exists.
246+ let _ = self . get_group ( parent_id) . await ?;
247+
248+ // Update the group to point to the new parent.
249+ let old_parent = group. parent_id ;
250+ group. parent_id = Set ( Some ( parent_id. 0 ) ) ;
251+ group. update ( & self . db ) . await ?;
252+
253+ // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
254+ let old_parent = old_parent. unwrap ( ) . map ( GroupId ) ;
255+ Ok ( old_parent)
172256 }
173257
174258 /// Adds a logical expression to an existing group via its ID.
@@ -192,10 +276,10 @@ impl PersistentMemo {
192276 group_id : GroupId ,
193277 logical_expression : LogicalExpression ,
194278 children : & [ GroupId ] ,
195- ) -> OptimizerResult < Result < LogicalExpressionId , LogicalExpressionId > > {
279+ ) -> OptimizerResult < Result < LogicalExpressionId , ( GroupId , LogicalExpressionId ) > > {
196280 // Check if the expression already exists anywhere in the memo table.
197281 if let Some ( existing_id) = self
198- . is_duplicate_logical_expression ( & logical_expression)
282+ . is_duplicate_logical_expression ( & logical_expression, children )
199283 . await ?
200284 {
201285 return Ok ( Err ( existing_id) ) ;
@@ -227,7 +311,15 @@ impl PersistentMemo {
227311 // Finally, insert the fingerprint of the logical expression as well.
228312 let new_expr: LogicalExpression = new_model. into ( ) ;
229313 let kind = new_expr. kind ( ) ;
230- let hash = new_expr. fingerprint ( ) ;
314+
315+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
316+ // groups of the children instead of the child ID themselves.
317+ let mut rewrites = vec ! [ ] ;
318+ for & child_id in children {
319+ let root_id = self . get_root_group ( child_id) . await ?;
320+ rewrites. push ( ( child_id, root_id) ) ;
321+ }
322+ let hash = new_expr. fingerprint_with_rewrite ( & rewrites) ;
231323
232324 let fingerprint = fingerprint:: ActiveModel {
233325 id : NotSet ,
@@ -285,8 +377,8 @@ impl PersistentMemo {
285377 /// In order to prevent a large amount of duplicate work, the memo table must support duplicate
286378 /// expression detection.
287379 ///
288- /// Returns `Some(expression_id)` if the memo table detects that the expression already exists,
289- /// and `None` otherwise.
380+ /// Returns `Some((group_id, expression_id)) ` if the memo table detects that the expression
381+ /// already exists, and `None` otherwise.
290382 ///
291383 /// This function assumes that the child groups of the expression are currently roots of their
292384 /// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input
@@ -296,13 +388,22 @@ impl PersistentMemo {
296388 pub async fn is_duplicate_logical_expression (
297389 & self ,
298390 logical_expression : & LogicalExpression ,
299- ) -> OptimizerResult < Option < LogicalExpressionId > > {
391+ children : & [ GroupId ] ,
392+ ) -> OptimizerResult < Option < ( GroupId , LogicalExpressionId ) > > {
300393 let model: logical_expression:: Model = logical_expression. clone ( ) . into ( ) ;
301394
302395 // Lookup all expressions that have the same fingerprint and kind. There may be false
303396 // positives, but we will check for those next.
304397 let kind = model. kind ;
305- let fingerprint = logical_expression. fingerprint ( ) ;
398+
399+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
400+ // groups of the children instead of the child ID themselves.
401+ let mut rewrites = vec ! [ ] ;
402+ for & child_id in children {
403+ let root_id = self . get_root_group ( child_id) . await ?;
404+ rewrites. push ( ( child_id, root_id) ) ;
405+ }
406+ let fingerprint = logical_expression. fingerprint_with_rewrite ( & rewrites) ;
306407
307408 // Filter first by the fingerprint, and then the kind.
308409 // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
@@ -323,11 +424,11 @@ impl PersistentMemo {
323424 let mut match_id = None ;
324425 for potential_match in potential_matches {
325426 let expr_id = LogicalExpressionId ( potential_match. logical_expression_id ) ;
326- let ( _ , expr) = self . get_logical_expression ( expr_id) . await ?;
427+ let ( group_id , expr) = self . get_logical_expression ( expr_id) . await ?;
327428
328429 // Check for an exact match.
329430 if & expr == logical_expression {
330- match_id = Some ( expr_id) ;
431+ match_id = Some ( ( group_id , expr_id) ) ;
331432
332433 // There should be at most one duplicate expression, so we can break here.
333434 break ;
@@ -359,18 +460,17 @@ impl PersistentMemo {
359460 ) -> OptimizerResult < Result < ( GroupId , LogicalExpressionId ) , ( GroupId , LogicalExpressionId ) > >
360461 {
361462 // Check if the expression already exists in the memo table.
362- if let Some ( existing_id) = self
363- . is_duplicate_logical_expression ( & logical_expression)
463+ if let Some ( ( group_id , existing_id) ) = self
464+ . is_duplicate_logical_expression ( & logical_expression, children )
364465 . await ?
365466 {
366- let ( group_id, _expr) = self . get_logical_expression ( existing_id) . await ?;
367467 return Ok ( Err ( ( group_id, existing_id) ) ) ;
368468 }
369469
370470 // The expression does not exist yet, so we need to create a new group and new expression.
371471 let group = cascades_group:: ActiveModel {
372472 winner : Set ( None ) ,
373- is_optimized : Set ( false ) ,
473+ status : Set ( 0 ) , // `GroupStatus::InProgress` status.
374474 ..Default :: default ( )
375475 } ;
376476
@@ -401,7 +501,15 @@ impl PersistentMemo {
401501 // Finally, insert the fingerprint of the logical expression as well.
402502 let new_expr: LogicalExpression = new_model. into ( ) ;
403503 let kind = new_expr. kind ( ) ;
404- let hash = new_expr. fingerprint ( ) ;
504+
505+ // In order to calculate a correct fingerprint, we will want to use the IDs of the root
506+ // groups of the children instead of the child ID themselves.
507+ let mut rewrites = vec ! [ ] ;
508+ for & child_id in children {
509+ let root_id = self . get_root_group ( child_id) . await ?;
510+ rewrites. push ( ( child_id, root_id) ) ;
511+ }
512+ let hash = new_expr. fingerprint_with_rewrite ( & rewrites) ;
405513
406514 let fingerprint = fingerprint:: ActiveModel {
407515 id : NotSet ,
0 commit comments