Skip to content

Commit 4b06b81

Browse files
committed
add union find and group merging support
1 parent 0a0af6d commit 4b06b81

File tree

5 files changed

+146
-16
lines changed

5 files changed

+146
-16
lines changed

optd-mvp/src/entities/cascades_group.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*;
77
pub struct Model {
88
#[sea_orm(primary_key)]
99
pub id: i32,
10+
pub status: i8,
1011
pub winner: Option<i32>,
1112
pub cost: Option<i64>,
12-
pub is_optimized: bool,
1313
pub parent_id: Option<i32>,
1414
}
1515

optd-mvp/src/memo/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32);
1919
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2020
pub struct PhysicalExpressionId(pub i32);
2121

22+
/// A status enum representing the different states a group can be during query optimization.
23+
#[repr(u8)]
24+
pub enum GroupStatus {
25+
InProgress = 0,
26+
Explored = 1,
27+
Optimized = 2,
28+
}
29+
2230
/// The different kinds of errors that might occur while running operations on a memo table.
2331
#[derive(Error, Debug)]
2432
pub enum MemoError {

optd-mvp/src/memo/persistent/implementation.rs

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::PersistentMemo;
1010
use crate::{
1111
entities::*,
1212
expression::{LogicalExpression, PhysicalExpression},
13-
memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId},
13+
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
1414
OptimizerResult, DATABASE_URL,
1515
};
1616
use sea_orm::{
@@ -66,6 +66,40 @@ impl PersistentMemo {
6666
.ok_or(MemoError::UnknownGroup(group_id))?)
6767
}
6868

69+
/// Retrieves the root / canonical group ID of the given group ID.
70+
///
71+
/// The groups form a union find / disjoint set parent pointer forest, where group merging
72+
/// causes two trees to merge.
73+
///
74+
/// This function uses the path compression optimization, which amortizes the cost to a single
75+
/// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip).
76+
pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult<GroupId> {
77+
let mut curr_group = self.get_group(group_id).await?;
78+
79+
// Traverse up the path and find the root group, keeping track of groups we have visited.
80+
let mut path = vec![];
81+
loop {
82+
let Some(parent_id) = curr_group.parent_id else {
83+
break;
84+
};
85+
86+
let next_group = self.get_group(GroupId(parent_id)).await?;
87+
path.push(curr_group);
88+
curr_group = next_group;
89+
}
90+
91+
let root_id = GroupId(curr_group.id);
92+
93+
// Path Compression Optimization:
94+
// For every group along the path that we walked, set their parent id pointer to the root.
95+
// This allows for an amortized O(1) cost for `get_root_group`.
96+
for group in path {
97+
self.update_group_parent(GroupId(group.id), root_id).await?;
98+
}
99+
100+
Ok(root_id)
101+
}
102+
69103
/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
70104
///
71105
/// If the physical expression does not exist, returns a
@@ -146,6 +180,32 @@ impl PersistentMemo {
146180
Ok(children)
147181
}
148182

183+
/// Updates / replaces a group's status. Returns the previous group status.
184+
///
185+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
186+
pub async fn update_group_status(
187+
&self,
188+
group_id: GroupId,
189+
status: GroupStatus,
190+
) -> OptimizerResult<GroupStatus> {
191+
// First retrieve the group record.
192+
let mut group = self.get_group(group_id).await?.into_active_model();
193+
194+
// Update the group's status.
195+
let old_status = group.status;
196+
group.status = Set(status as u8 as i8);
197+
group.update(&self.db).await?;
198+
199+
let old_status = match old_status.unwrap() {
200+
0 => GroupStatus::InProgress,
201+
1 => GroupStatus::Explored,
202+
2 => GroupStatus::Optimized,
203+
_ => panic!("encountered an invalid group status"),
204+
};
205+
206+
Ok(old_status)
207+
}
208+
149209
/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
150210
/// winner's physical expression ID.
151211
///
@@ -167,8 +227,45 @@ impl PersistentMemo {
167227
group.update(&self.db).await?;
168228

169229
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
170-
let old = old_id.unwrap().map(PhysicalExpressionId);
171-
Ok(old)
230+
let old_id = old_id.unwrap().map(PhysicalExpressionId);
231+
Ok(old_id)
232+
}
233+
234+
/// Updates / replaces a group's parent group. Optionally returns the previous parent.
235+
///
236+
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
237+
pub async fn update_group_parent(
238+
&self,
239+
group_id: GroupId,
240+
parent_id: GroupId,
241+
) -> OptimizerResult<Option<GroupId>> {
242+
// First retrieve the group record.
243+
let mut group = self.get_group(group_id).await?.into_active_model();
244+
245+
// Check that the parent group exists.
246+
let _ = self.get_group(parent_id).await?;
247+
248+
// Update the group to point to the new parent.
249+
let old_parent = group.parent_id;
250+
group.parent_id = Set(Some(parent_id.0));
251+
group.update(&self.db).await?;
252+
253+
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
254+
let old_parent = old_parent.unwrap().map(GroupId);
255+
Ok(old_parent)
256+
}
257+
258+
/// Merges two groups sets together. Returns the new root group of the unioned sets.
259+
///
260+
/// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error.
261+
///
262+
/// TODO use union by rank / size as an optimization?
263+
pub async fn merge_groups(&self, group1: GroupId, group2: GroupId) -> OptimizerResult<GroupId> {
264+
// Without tracking the size of each of the groups, it is arbitrary which group is better to
265+
// merge into the other. So we will arbitrarily choose `group1` to merge into `group2`.
266+
self.update_group_parent(group1, group2).await?;
267+
268+
Ok(group2)
172269
}
173270

174271
/// Adds a logical expression to an existing group via its ID.
@@ -195,7 +292,7 @@ impl PersistentMemo {
195292
) -> OptimizerResult<Result<LogicalExpressionId, LogicalExpressionId>> {
196293
// Check if the expression already exists anywhere in the memo table.
197294
if let Some(existing_id) = self
198-
.is_duplicate_logical_expression(&logical_expression)
295+
.is_duplicate_logical_expression(&logical_expression, children)
199296
.await?
200297
{
201298
return Ok(Err(existing_id));
@@ -227,7 +324,15 @@ impl PersistentMemo {
227324
// Finally, insert the fingerprint of the logical expression as well.
228325
let new_expr: LogicalExpression = new_model.into();
229326
let kind = new_expr.kind();
230-
let hash = new_expr.fingerprint();
327+
328+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
329+
// groups of the children instead of the child ID themselves.
330+
let mut rewrites = vec![];
331+
for &child_id in children {
332+
let root_id = self.get_root_group(child_id).await?;
333+
rewrites.push((child_id, root_id));
334+
}
335+
let hash = new_expr.fingerprint_with_rewrite(&rewrites);
231336

232337
let fingerprint = fingerprint::ActiveModel {
233338
id: NotSet,
@@ -296,13 +401,22 @@ impl PersistentMemo {
296401
pub async fn is_duplicate_logical_expression(
297402
&self,
298403
logical_expression: &LogicalExpression,
404+
children: &[GroupId],
299405
) -> OptimizerResult<Option<LogicalExpressionId>> {
300406
let model: logical_expression::Model = logical_expression.clone().into();
301407

302408
// Lookup all expressions that have the same fingerprint and kind. There may be false
303409
// positives, but we will check for those next.
304410
let kind = model.kind;
305-
let fingerprint = logical_expression.fingerprint();
411+
412+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
413+
// groups of the children instead of the child ID themselves.
414+
let mut rewrites = vec![];
415+
for &child_id in children {
416+
let root_id = self.get_root_group(child_id).await?;
417+
rewrites.push((child_id, root_id));
418+
}
419+
let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites);
306420

307421
// Filter first by the fingerprint, and then the kind.
308422
// FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
@@ -360,7 +474,7 @@ impl PersistentMemo {
360474
{
361475
// Check if the expression already exists in the memo table.
362476
if let Some(existing_id) = self
363-
.is_duplicate_logical_expression(&logical_expression)
477+
.is_duplicate_logical_expression(&logical_expression, children)
364478
.await?
365479
{
366480
let (group_id, _expr) = self.get_logical_expression(existing_id).await?;
@@ -370,7 +484,7 @@ impl PersistentMemo {
370484
// The expression does not exist yet, so we need to create a new group and new expression.
371485
let group = cascades_group::ActiveModel {
372486
winner: Set(None),
373-
is_optimized: Set(false),
487+
status: Set(0), // `GroupStatus::InProgress` status.
374488
..Default::default()
375489
};
376490

@@ -401,7 +515,15 @@ impl PersistentMemo {
401515
// Finally, insert the fingerprint of the logical expression as well.
402516
let new_expr: LogicalExpression = new_model.into();
403517
let kind = new_expr.kind();
404-
let hash = new_expr.fingerprint();
518+
519+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
520+
// groups of the children instead of the child ID themselves.
521+
let mut rewrites = vec![];
522+
for &child_id in children {
523+
let root_id = self.get_root_group(child_id).await?;
524+
rewrites.push((child_id, root_id));
525+
}
526+
let hash = new_expr.fingerprint_with_rewrite(&rewrites);
405527

406528
let fingerprint = fingerprint::ActiveModel {
407529
id: NotSet,

optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*};
7474
pub enum CascadesGroup {
7575
Table,
7676
Id,
77+
Status,
7778
Winner,
7879
Cost,
79-
IsOptimized,
8080
ParentId,
8181
}
8282

@@ -92,16 +92,16 @@ impl MigrationTrait for Migration {
9292
.table(CascadesGroup::Table)
9393
.if_not_exists()
9494
.col(pk_auto(CascadesGroup::Id))
95+
.col(tiny_integer(CascadesGroup::Status))
9596
.col(integer_null(CascadesGroup::Winner))
96-
.col(big_unsigned_null(CascadesGroup::Cost))
97+
.col(big_integer_null(CascadesGroup::Cost))
9798
.foreign_key(
9899
ForeignKey::create()
99100
.from(CascadesGroup::Table, CascadesGroup::Winner)
100101
.to(PhysicalExpression::Table, PhysicalExpression::Id)
101102
.on_delete(ForeignKeyAction::SetNull)
102103
.on_update(ForeignKeyAction::Cascade),
103104
)
104-
.col(boolean(CascadesGroup::IsOptimized))
105105
.col(integer_null(CascadesGroup::ParentId))
106106
.foreign_key(
107107
ForeignKey::create()

optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ impl MigrationTrait for Migration {
2626
.table(Fingerprint::Table)
2727
.if_not_exists()
2828
.col(pk_auto(Fingerprint::Id))
29-
.col(unsigned(Fingerprint::LogicalExpressionId))
29+
.col(integer(Fingerprint::LogicalExpressionId))
3030
.foreign_key(
3131
ForeignKey::create()
3232
.from(Fingerprint::Table, Fingerprint::LogicalExpressionId)
3333
.to(LogicalExpression::Table, LogicalExpression::Id)
3434
.on_delete(ForeignKeyAction::Cascade)
3535
.on_update(ForeignKeyAction::Cascade),
3636
)
37-
.col(small_unsigned(Fingerprint::Kind))
38-
.col(big_unsigned(Fingerprint::Hash))
37+
.col(small_integer(Fingerprint::Kind))
38+
.col(big_integer(Fingerprint::Hash))
3939
.to_owned(),
4040
)
4141
.await

0 commit comments

Comments
 (0)