@@ -90,12 +90,55 @@ impl PersistentMemo {
9090 // For every group along the path that we walked, set their parent id pointer to the root.
9191 // This allows for an amortized O(1) cost for `get_root_group`.
9292 for group in path {
93- self . update_group_parent ( GroupId ( group. id ) , root_id) . await ?;
93+ let mut active_group = group. into_active_model ( ) ;
94+
95+ // Update the group to point to the new parent.
96+ active_group. parent_id = Set ( Some ( root_id. 0 ) ) ;
97+ active_group. update ( & self . db ) . await ?;
9498 }
9599
96100 Ok ( root_id)
97101 }
98102
103+ /// Retrieves every group ID of groups that share the same root group with the input group.
104+ ///
105+ /// If a group does not exist in the cycle, returns a [`MemoError::UnknownGroup`] error.
106+ ///
107+ /// The group records form a union-find data structure that also maintains a circular linked
108+ /// list in every set that allows us to iterate over all elements in a set in linear time.
109+ pub async fn get_group_set ( & self , group_id : GroupId ) -> OptimizerResult < Vec < GroupId > > {
110+ // Iterate over the circular linked list until we reach ourselves again.
111+ let base_group = self . get_group ( group_id) . await ?;
112+
113+ // The only case when `next_id` is set to `None` is if the current group is a root, which
114+ // means that this group is the only group in the set.
115+ if base_group. next_id . is_none ( ) {
116+ assert ! ( base_group. parent_id. is_none( ) ) ;
117+ return Ok ( vec ! [ group_id] ) ;
118+ }
119+
120+ // Iterate over the circular linked list until we see ourselves again, collecting nodes
121+ // along the way.
122+ let mut set = vec ! [ ] ;
123+ let mut next_id = base_group
124+ . next_id
125+ . expect ( "next pointer cannot be null if it is in a cycle" ) ;
126+ loop {
127+ let curr_group = self . get_group ( GroupId ( next_id) ) . await ?;
128+
129+ if curr_group. id == group_id. 0 {
130+ break ;
131+ }
132+
133+ set. push ( GroupId ( curr_group. id ) ) ;
134+ next_id = curr_group
135+ . next_id
136+ . expect ( "next pointer cannot be null if it is in a cycle" ) ;
137+ }
138+
139+ Ok ( set)
140+ }
141+
99142 /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
100143 ///
101144 /// If the physical expression does not exist, returns a
@@ -227,30 +270,6 @@ impl PersistentMemo {
227270 Ok ( old_id)
228271 }
229272
230- /// Updates / replaces a group's parent group. Optionally returns the previous parent.
231- ///
232- /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
233- pub async fn update_group_parent (
234- & self ,
235- group_id : GroupId ,
236- parent_id : GroupId ,
237- ) -> OptimizerResult < Option < GroupId > > {
238- // First retrieve the group record.
239- let mut group = self . get_group ( group_id) . await ?. into_active_model ( ) ;
240-
241- // Check that the parent group exists.
242- let _ = self . get_group ( parent_id) . await ?;
243-
244- // Update the group to point to the new parent.
245- let old_parent = group. parent_id ;
246- group. parent_id = Set ( Some ( parent_id. 0 ) ) ;
247- group. update ( & self . db ) . await ?;
248-
249- // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
250- let old_parent = old_parent. unwrap ( ) . map ( GroupId ) ;
251- Ok ( old_parent)
252- }
253-
254273 /// Adds a logical expression to an existing group via its ID.
255274 ///
256275 /// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of
@@ -323,7 +342,7 @@ impl PersistentMemo {
323342 kind : Set ( kind) ,
324343 hash : Set ( hash) ,
325344 } ;
326- let _ = fingerprint:: Entity :: insert ( fingerprint)
345+ fingerprint:: Entity :: insert ( fingerprint)
327346 . exec ( & self . db )
328347 . await ?;
329348
@@ -513,10 +532,92 @@ impl PersistentMemo {
513532 kind : Set ( kind) ,
514533 hash : Set ( hash) ,
515534 } ;
516- let _ = fingerprint:: Entity :: insert ( fingerprint)
535+ fingerprint:: Entity :: insert ( fingerprint)
517536 . exec ( & self . db )
518537 . await ?;
519538
520539 Ok ( Ok ( ( GroupId ( group_id) , LogicalExpressionId ( expr_id) ) ) )
521540 }
541+
542+ /// Merges two groups sets together.
543+ ///
544+ /// If either of the input groups do not exist, returns a [`MemoError::UnknownGroup`] error.
545+ ///
546+ /// TODO write docs.
547+ /// TODO highly inefficient, need to understand metrics and performance testing.
548+ /// TODO Optimization: add rank / size into data structure
549+ pub async fn merge_groups (
550+ & self ,
551+ left_group_id : GroupId ,
552+ right_group_id : GroupId ,
553+ ) -> OptimizerResult < GroupId > {
554+ // Without a rank / size field, we have no way of determining which set is better to merge
555+ // into the other. So we will arbitrarily choose to merge the left group into the right
556+ // group here. If rank is added in the future, then merge the smaller set into the larger.
557+
558+ let left_root_id = self . get_root_group ( left_group_id) . await ?;
559+ let left_root = self . get_group ( left_root_id) . await ?;
560+ // A `None` next pointer means it should technically be pointing to itself.
561+ let left_next = left_root. next_id . unwrap_or ( left_root_id. 0 ) ;
562+ let mut active_left_root = left_root. into_active_model ( ) ;
563+
564+ let right_root_id = self . get_root_group ( right_group_id) . await ?;
565+ let right_root = self . get_group ( right_root_id) . await ?;
566+ // A `None` next pointer means it should technically be pointing to itself.
567+ let right_next = right_root. next_id . unwrap_or ( right_root_id. 0 ) ;
568+ let mut active_right_root = right_root. into_active_model ( ) ;
569+
570+ // Before we actually update the group records, We first need to generate new fingerprints
571+ // for every single expression that has a child group in the left set.
572+ // TODO make this more efficient, this code is doing double work from `get_group_set`.
573+ let group_set_ids = self . get_group_set ( left_group_id) . await ?;
574+ let mut left_group_models = Vec :: with_capacity ( group_set_ids. len ( ) ) ;
575+ for & group_id in & group_set_ids {
576+ left_group_models. push ( self . get_group ( group_id) . await ?) ;
577+ }
578+
579+ // Retrieve every single expression that has a child group in the left set.
580+ let left_group_expressions: Vec < Vec < logical_expression:: Model > > = left_group_models
581+ . load_many_to_many (
582+ logical_expression:: Entity ,
583+ logical_children:: Entity ,
584+ & self . db ,
585+ )
586+ . await ?;
587+
588+ // Need to replace every single occurrence of groups in the set with the new root.
589+ let rewrites: Vec < ( GroupId , GroupId ) > = group_set_ids
590+ . iter ( )
591+ . map ( |& group_id| ( group_id, right_root_id) )
592+ . collect ( ) ;
593+
594+ // For each expression, generate a new fingerprint.
595+ for model in left_group_expressions. into_iter ( ) . flatten ( ) {
596+ let expr_id = model. id ;
597+ let logical_expression: LogicalExpression = model. into ( ) ;
598+ let hash = logical_expression. fingerprint_with_rewrite ( & rewrites) ;
599+
600+ let fingerprint = fingerprint:: ActiveModel {
601+ id : NotSet ,
602+ logical_expression_id : Set ( expr_id) ,
603+ kind : Set ( logical_expression. kind ( ) ) ,
604+ hash : Set ( hash) ,
605+ } ;
606+ fingerprint:: Entity :: insert ( fingerprint)
607+ . exec ( & self . db )
608+ . await ?;
609+ }
610+
611+ // Update the left group root to point to the right group root.
612+ active_left_root. parent_id = Set ( Some ( right_root_id. 0 ) ) ;
613+
614+ // Swap the next pointers of each root to maintain the circular linked list.
615+ active_left_root. next_id = Set ( Some ( right_next) ) ;
616+ active_right_root. next_id = Set ( Some ( left_next) ) ;
617+
618+ active_left_root. update ( & self . db ) . await ?;
619+ active_right_root. update ( & self . db ) . await ?;
620+
621+ Ok ( right_root_id)
622+ }
522623}
0 commit comments