Skip to content

Commit b92047c

Browse files
committed
first draft fully merge group
1 parent 194ae5e commit b92047c

File tree

2 files changed

+129
-39
lines changed

2 files changed

+129
-39
lines changed

optd-mvp/src/memo/persistent/implementation.rs

Lines changed: 128 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,55 @@ impl PersistentMemo {
9090
// For every group along the path that we walked, set their parent id pointer to the root.
9191
// This allows for an amortized O(1) cost for `get_root_group`.
9292
for group in path {
93-
self.update_group_parent(GroupId(group.id), root_id).await?;
93+
let mut active_group = group.into_active_model();
94+
95+
// Update the group to point to the new parent.
96+
active_group.parent_id = Set(Some(root_id.0));
97+
active_group.update(&self.db).await?;
9498
}
9599

96100
Ok(root_id)
97101
}
98102

103+
/// Retrieves every group ID of groups that share the same root group with the input group.
104+
///
105+
/// If a group does not exist in the cycle, returns a [`MemoError::UnknownGroup`] error.
106+
///
107+
/// The group records form a union-find data structure that also maintains a circular linked
108+
/// list in every set that allows us to iterate over all elements in a set in linear time.
109+
pub async fn get_group_set(&self, group_id: GroupId) -> OptimizerResult<Vec<GroupId>> {
110+
// Fetch the starting group; the walk over the circular linked list below begins from it.
111+
let base_group = self.get_group(group_id).await?;
112+
113+
// The only case when `next_id` is set to `None` is if the current group is a root, which
114+
// means that this group is the only group in the set.
115+
if base_group.next_id.is_none() {
116+
assert!(base_group.parent_id.is_none());
117+
return Ok(vec![group_id]);
118+
}
119+
120+
// Iterate over the circular linked list until we see ourselves again, collecting nodes
121+
// along the way.
122+
let mut set = vec![];
123+
let mut next_id = base_group
124+
.next_id
125+
.expect("next pointer cannot be null if it is in a cycle");
126+
loop {
127+
let curr_group = self.get_group(GroupId(next_id)).await?;
128+
129+
if curr_group.id == group_id.0 {
130+
break;
131+
}
132+
133+
set.push(GroupId(curr_group.id));
134+
next_id = curr_group
135+
.next_id
136+
.expect("next pointer cannot be null if it is in a cycle");
137+
}
138+
139+
Ok(set)
140+
}
141+
99142
/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
100143
///
101144
/// If the physical expression does not exist, returns a
@@ -227,30 +270,6 @@ impl PersistentMemo {
227270
Ok(old_id)
228271
}
229272

230-
/// Updates / replaces a group's parent group. Optionally returns the previous parent.
231-
///
232-
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
233-
pub async fn update_group_parent(
234-
&self,
235-
group_id: GroupId,
236-
parent_id: GroupId,
237-
) -> OptimizerResult<Option<GroupId>> {
238-
// First retrieve the group record.
239-
let mut group = self.get_group(group_id).await?.into_active_model();
240-
241-
// Check that the parent group exists.
242-
let _ = self.get_group(parent_id).await?;
243-
244-
// Update the group to point to the new parent.
245-
let old_parent = group.parent_id;
246-
group.parent_id = Set(Some(parent_id.0));
247-
group.update(&self.db).await?;
248-
249-
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
250-
let old_parent = old_parent.unwrap().map(GroupId);
251-
Ok(old_parent)
252-
}
253-
254273
/// Adds a logical expression to an existing group via its ID.
255274
///
256275
/// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of
@@ -323,7 +342,7 @@ impl PersistentMemo {
323342
kind: Set(kind),
324343
hash: Set(hash),
325344
};
326-
let _ = fingerprint::Entity::insert(fingerprint)
345+
fingerprint::Entity::insert(fingerprint)
327346
.exec(&self.db)
328347
.await?;
329348

@@ -513,10 +532,92 @@ impl PersistentMemo {
513532
kind: Set(kind),
514533
hash: Set(hash),
515534
};
516-
let _ = fingerprint::Entity::insert(fingerprint)
535+
fingerprint::Entity::insert(fingerprint)
517536
.exec(&self.db)
518537
.await?;
519538

520539
Ok(Ok((GroupId(group_id), LogicalExpressionId(expr_id))))
521540
}
541+
542+
/// Merges two group sets together.
543+
///
544+
/// If either of the input groups does not exist, returns a [`MemoError::UnknownGroup`] error.
545+
///
546+
/// TODO write docs.
547+
/// TODO highly inefficient, need to understand metrics and performance testing.
548+
/// TODO Optimization: add rank / size into data structure
549+
pub async fn merge_groups(
550+
&self,
551+
left_group_id: GroupId,
552+
right_group_id: GroupId,
553+
) -> OptimizerResult<GroupId> {
554+
// Without a rank / size field, we have no way of determining which set is better to merge
555+
// into the other. So we will arbitrarily choose to merge the left group into the right
556+
// group here. If rank is added in the future, then merge the smaller set into the larger.
557+
558+
let left_root_id = self.get_root_group(left_group_id).await?;
559+
let left_root = self.get_group(left_root_id).await?;
560+
// A `None` next pointer means it should technically be pointing to itself.
561+
let left_next = left_root.next_id.unwrap_or(left_root_id.0);
562+
let mut active_left_root = left_root.into_active_model();
563+
564+
let right_root_id = self.get_root_group(right_group_id).await?;
565+
let right_root = self.get_group(right_root_id).await?;
566+
// A `None` next pointer means it should technically be pointing to itself.
567+
let right_next = right_root.next_id.unwrap_or(right_root_id.0);
568+
let mut active_right_root = right_root.into_active_model();
569+
570+
// Before we actually update the group records, we first need to generate new fingerprints
571+
// for every single expression that has a child group in the left set.
572+
// TODO make this more efficient, this code is doing double work from `get_group_set`.
573+
let group_set_ids = self.get_group_set(left_group_id).await?;
574+
let mut left_group_models = Vec::with_capacity(group_set_ids.len());
575+
for &group_id in &group_set_ids {
576+
left_group_models.push(self.get_group(group_id).await?);
577+
}
578+
579+
// Retrieve every single expression that has a child group in the left set.
580+
let left_group_expressions: Vec<Vec<logical_expression::Model>> = left_group_models
581+
.load_many_to_many(
582+
logical_expression::Entity,
583+
logical_children::Entity,
584+
&self.db,
585+
)
586+
.await?;
587+
588+
// Need to replace every single occurrence of groups in the set with the new root.
589+
let rewrites: Vec<(GroupId, GroupId)> = group_set_ids
590+
.iter()
591+
.map(|&group_id| (group_id, right_root_id))
592+
.collect();
593+
594+
// For each expression, generate a new fingerprint.
595+
for model in left_group_expressions.into_iter().flatten() {
596+
let expr_id = model.id;
597+
let logical_expression: LogicalExpression = model.into();
598+
let hash = logical_expression.fingerprint_with_rewrite(&rewrites);
599+
600+
let fingerprint = fingerprint::ActiveModel {
601+
id: NotSet,
602+
logical_expression_id: Set(expr_id),
603+
kind: Set(logical_expression.kind()),
604+
hash: Set(hash),
605+
};
606+
fingerprint::Entity::insert(fingerprint)
607+
.exec(&self.db)
608+
.await?;
609+
}
610+
611+
// Update the left group root to point to the right group root.
612+
active_left_root.parent_id = Set(Some(right_root_id.0));
613+
614+
// Swap the next pointers of each root to maintain the circular linked list.
615+
active_left_root.next_id = Set(Some(right_next));
616+
active_right_root.next_id = Set(Some(left_next));
617+
618+
active_left_root.update(&self.db).await?;
619+
active_right_root.update(&self.db).await?;
620+
621+
Ok(right_root_id)
622+
}
522623
}

optd-mvp/src/memo/persistent/tests.rs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,22 +191,11 @@ async fn test_simple_group_link() {
191191

192192
// The above tells the application that the expression already exists in the memo, specifically
193193
// under `existing_group`. Thus, we should link these two groups together.
194-
// Here, we arbitrarily choose to link group 1 into group 2.
195-
memo.update_group_parent(join_group_1, join_group_2)
196-
.await
197-
.unwrap();
194+
memo.merge_groups(join_group_1, join_group_2).await.unwrap();
198195

199196
let test_root_1 = memo.get_root_group(join_group_1).await.unwrap();
200197
let test_root_2 = memo.get_root_group(join_group_2).await.unwrap();
201198
assert_eq!(test_root_1, test_root_2);
202199

203-
// TODO(Connor)
204-
//
205-
// We now need to find all logical expressions that had group 1 (or whatever the root group of
206-
// the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a
207-
// new fingerprint for each one that uses group 2 as a child instead.
208-
//
209-
// In order to do this, we need to iterate through every group in group 1's set.
210-
211200
memo.cleanup().await;
212201
}

0 commit comments

Comments
 (0)