Skip to content

Commit 9d9db84

Browse files
committed
add union find and group merging support
1 parent 0a0af6d commit 9d9db84

File tree

8 files changed

+210
-26
lines changed

8 files changed

+210
-26
lines changed

optd-mvp/DESIGN.md

Whitespace-only changes.
File renamed without changes.

optd-mvp/src/entities/cascades_group.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*;
77
pub struct Model {
88
#[sea_orm(primary_key)]
99
pub id: i32,
10+
pub status: i8,
1011
pub winner: Option<i32>,
1112
pub cost: Option<i64>,
12-
pub is_optimized: bool,
1313
pub parent_id: Option<i32>,
1414
}
1515

optd-mvp/src/memo/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32);
1919
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
2020
pub struct PhysicalExpressionId(pub i32);
2121

22+
/// A status enum representing the different states a group can be in during query optimization.
23+
#[repr(u8)]
24+
pub enum GroupStatus {
25+
InProgress = 0,
26+
Explored = 1,
27+
Optimized = 2,
28+
}
29+
2230
/// The different kinds of errors that might occur while running operations on a memo table.
2331
#[derive(Error, Debug)]
2432
pub enum MemoError {

optd-mvp/src/memo/persistent/implementation.rs

Lines changed: 125 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::PersistentMemo;
1010
use crate::{
1111
entities::*,
1212
expression::{LogicalExpression, PhysicalExpression},
13-
memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId},
13+
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
1414
OptimizerResult, DATABASE_URL,
1515
};
1616
use sea_orm::{
@@ -66,6 +66,40 @@ impl PersistentMemo {
6666
.ok_or(MemoError::UnknownGroup(group_id))?)
6767
}
6868

69+
/// Retrieves the root / canonical group ID of the given group ID.
70+
///
71+
/// The groups form a union find / disjoint set parent pointer forest, where group merging
72+
/// causes two trees to merge.
73+
///
74+
/// This function uses the path compression optimization, which amortizes the cost to a single
75+
/// lookup on later calls (close to constant time in theory, but we must be wary of the I/O roundtrip).
76+
pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult<GroupId> {
77+
let mut curr_group = self.get_group(group_id).await?;
78+
79+
// Traverse up the path and find the root group, keeping track of groups we have visited.
80+
let mut path = vec![];
81+
loop {
82+
let Some(parent_id) = curr_group.parent_id else {
83+
break;
84+
};
85+
86+
let next_group = self.get_group(GroupId(parent_id)).await?;
87+
path.push(curr_group);
88+
curr_group = next_group;
89+
}
90+
91+
let root_id = GroupId(curr_group.id);
92+
93+
// Path Compression Optimization:
94+
// For every group along the path that we walked, set their parent id pointer to the root.
95+
// This makes later `get_root_group` calls on these groups a single hop (amortized near-constant cost).
96+
for group in path {
97+
self.update_group_parent(GroupId(group.id), root_id).await?;
98+
}
99+
100+
Ok(root_id)
101+
}
102+
69103
/// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`].
70104
///
71105
/// If the physical expression does not exist, returns a
@@ -146,6 +180,32 @@ impl PersistentMemo {
146180
Ok(children)
147181
}
148182

183+
/// Updates / replaces a group's status. Returns the previous group status.
184+
///
185+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
186+
pub async fn update_group_status(
187+
&self,
188+
group_id: GroupId,
189+
status: GroupStatus,
190+
) -> OptimizerResult<GroupStatus> {
191+
// First retrieve the group record.
192+
let mut group = self.get_group(group_id).await?.into_active_model();
193+
194+
// Update the group's status.
195+
let old_status = group.status;
196+
group.status = Set(status as u8 as i8);
197+
group.update(&self.db).await?;
198+
199+
let old_status = match old_status.unwrap() {
200+
0 => GroupStatus::InProgress,
201+
1 => GroupStatus::Explored,
202+
2 => GroupStatus::Optimized,
203+
_ => panic!("encountered an invalid group status"),
204+
};
205+
206+
Ok(old_status)
207+
}
208+
149209
/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
150210
/// winner's physical expression ID.
151211
///
@@ -167,8 +227,32 @@ impl PersistentMemo {
167227
group.update(&self.db).await?;
168228

169229
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
170-
let old = old_id.unwrap().map(PhysicalExpressionId);
171-
Ok(old)
230+
let old_id = old_id.unwrap().map(PhysicalExpressionId);
231+
Ok(old_id)
232+
}
233+
234+
/// Updates / replaces a group's parent group. Optionally returns the previous parent.
235+
///
236+
/// If either of the groups does not exist, returns a [`MemoError::UnknownGroup`] error.
237+
pub async fn update_group_parent(
238+
&self,
239+
group_id: GroupId,
240+
parent_id: GroupId,
241+
) -> OptimizerResult<Option<GroupId>> {
242+
// First retrieve the group record.
243+
let mut group = self.get_group(group_id).await?.into_active_model();
244+
245+
// Check that the parent group exists.
246+
let _ = self.get_group(parent_id).await?;
247+
248+
// Update the group to point to the new parent.
249+
let old_parent = group.parent_id;
250+
group.parent_id = Set(Some(parent_id.0));
251+
group.update(&self.db).await?;
252+
253+
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
254+
let old_parent = old_parent.unwrap().map(GroupId);
255+
Ok(old_parent)
172256
}
173257

174258
/// Adds a logical expression to an existing group via its ID.
@@ -192,10 +276,10 @@ impl PersistentMemo {
192276
group_id: GroupId,
193277
logical_expression: LogicalExpression,
194278
children: &[GroupId],
195-
) -> OptimizerResult<Result<LogicalExpressionId, LogicalExpressionId>> {
279+
) -> OptimizerResult<Result<LogicalExpressionId, (GroupId, LogicalExpressionId)>> {
196280
// Check if the expression already exists anywhere in the memo table.
197281
if let Some(existing_id) = self
198-
.is_duplicate_logical_expression(&logical_expression)
282+
.is_duplicate_logical_expression(&logical_expression, children)
199283
.await?
200284
{
201285
return Ok(Err(existing_id));
@@ -227,7 +311,15 @@ impl PersistentMemo {
227311
// Finally, insert the fingerprint of the logical expression as well.
228312
let new_expr: LogicalExpression = new_model.into();
229313
let kind = new_expr.kind();
230-
let hash = new_expr.fingerprint();
314+
315+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
316+
// groups of the children instead of the child ID themselves.
317+
let mut rewrites = vec![];
318+
for &child_id in children {
319+
let root_id = self.get_root_group(child_id).await?;
320+
rewrites.push((child_id, root_id));
321+
}
322+
let hash = new_expr.fingerprint_with_rewrite(&rewrites);
231323

232324
let fingerprint = fingerprint::ActiveModel {
233325
id: NotSet,
@@ -285,8 +377,8 @@ impl PersistentMemo {
285377
/// In order to prevent a large amount of duplicate work, the memo table must support duplicate
286378
/// expression detection.
287379
///
288-
/// Returns `Some(expression_id)` if the memo table detects that the expression already exists,
289-
/// and `None` otherwise.
380+
/// Returns `Some((group_id, expression_id))` if the memo table detects that the expression
381+
/// already exists, and `None` otherwise.
290382
///
291383
/// This function assumes that the child groups of the expression are currently roots of their
292384
/// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input
@@ -296,13 +388,22 @@ impl PersistentMemo {
296388
pub async fn is_duplicate_logical_expression(
297389
&self,
298390
logical_expression: &LogicalExpression,
299-
) -> OptimizerResult<Option<LogicalExpressionId>> {
391+
children: &[GroupId],
392+
) -> OptimizerResult<Option<(GroupId, LogicalExpressionId)>> {
300393
let model: logical_expression::Model = logical_expression.clone().into();
301394

302395
// Lookup all expressions that have the same fingerprint and kind. There may be false
303396
// positives, but we will check for those next.
304397
let kind = model.kind;
305-
let fingerprint = logical_expression.fingerprint();
398+
399+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
400+
// groups of the children instead of the child IDs themselves.
401+
let mut rewrites = vec![];
402+
for &child_id in children {
403+
let root_id = self.get_root_group(child_id).await?;
404+
rewrites.push((child_id, root_id));
405+
}
406+
let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites);
306407

307408
// Filter first by the fingerprint, and then the kind.
308409
// FIXME: The kind is already embedded into the fingerprint, so we may not actually need the
@@ -323,11 +424,11 @@ impl PersistentMemo {
323424
let mut match_id = None;
324425
for potential_match in potential_matches {
325426
let expr_id = LogicalExpressionId(potential_match.logical_expression_id);
326-
let (_, expr) = self.get_logical_expression(expr_id).await?;
427+
let (group_id, expr) = self.get_logical_expression(expr_id).await?;
327428

328429
// Check for an exact match.
329430
if &expr == logical_expression {
330-
match_id = Some(expr_id);
431+
match_id = Some((group_id, expr_id));
331432

332433
// There should be at most one duplicate expression, so we can break here.
333434
break;
@@ -359,18 +460,17 @@ impl PersistentMemo {
359460
) -> OptimizerResult<Result<(GroupId, LogicalExpressionId), (GroupId, LogicalExpressionId)>>
360461
{
361462
// Check if the expression already exists in the memo table.
362-
if let Some(existing_id) = self
363-
.is_duplicate_logical_expression(&logical_expression)
463+
if let Some((group_id, existing_id)) = self
464+
.is_duplicate_logical_expression(&logical_expression, children)
364465
.await?
365466
{
366-
let (group_id, _expr) = self.get_logical_expression(existing_id).await?;
367467
return Ok(Err((group_id, existing_id)));
368468
}
369469

370470
// The expression does not exist yet, so we need to create a new group and new expression.
371471
let group = cascades_group::ActiveModel {
372472
winner: Set(None),
373-
is_optimized: Set(false),
473+
status: Set(0), // `GroupStatus::InProgress` status.
374474
..Default::default()
375475
};
376476

@@ -401,7 +501,15 @@ impl PersistentMemo {
401501
// Finally, insert the fingerprint of the logical expression as well.
402502
let new_expr: LogicalExpression = new_model.into();
403503
let kind = new_expr.kind();
404-
let hash = new_expr.fingerprint();
504+
505+
// In order to calculate a correct fingerprint, we will want to use the IDs of the root
506+
// groups of the children instead of the child IDs themselves.
507+
let mut rewrites = vec![];
508+
for &child_id in children {
509+
let root_id = self.get_root_group(child_id).await?;
510+
rewrites.push((child_id, root_id));
511+
}
512+
let hash = new_expr.fingerprint_with_rewrite(&rewrites);
405513

406514
let fingerprint = fingerprint::ActiveModel {
407515
id: NotSet,

optd-mvp/src/memo/persistent/tests.rs

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,22 @@ async fn test_simple_logical_duplicates() {
3434
// Test `add_logical_expression_to_group`.
3535
{
3636
// Attempting to add a duplicate expression into the same group should also fail every time.
37-
let logical_expression_id_2a = memo
37+
let (group_id_2a, logical_expression_id_2a) = memo
3838
.add_logical_expression_to_group(group_id, scan2a, &[])
3939
.await
4040
.unwrap()
4141
.err()
4242
.unwrap();
43+
assert_eq!(group_id, group_id_2a);
4344
assert_eq!(logical_expression_id, logical_expression_id_2a);
4445

45-
let logical_expression_id_2b = memo
46+
let (group_id_2b, logical_expression_id_2b) = memo
4647
.add_logical_expression_to_group(group_id, scan2b, &[])
4748
.await
4849
.unwrap()
4950
.err()
5051
.unwrap();
52+
assert_eq!(group_id, group_id_2b);
5153
assert_eq!(logical_expression_id, logical_expression_id_2b);
5254
}
5355

@@ -140,3 +142,69 @@ async fn test_simple_tree() {
140142

141143
memo.cleanup().await;
142144
}
145+
146+
/// Tests basic group merging. See comments in the test itself for more information.
147+
#[ignore]
148+
#[tokio::test]
149+
async fn test_simple_group_link() {
150+
let memo = PersistentMemo::new().await;
151+
memo.cleanup().await;
152+
153+
// Create two scan groups.
154+
let scan1 = scan("t1".to_string());
155+
let scan2 = scan("t2".to_string());
156+
let (scan_id_1, _) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap();
157+
let (scan_id_2, _) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap();
158+
159+
// Create two join expressions that should be in the same group.
160+
// Even though these are obviously the same expression (to humans), the fingerprints will be
161+
// different, and so they will be put into different groups.
162+
let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string());
163+
let join2 = join(scan_id_2, scan_id_1, "t2.b = t1.a".to_string());
164+
let join_unknown = join2.clone();
165+
166+
let (join_group_1, _) = memo
167+
.add_group(join1, &[scan_id_1, scan_id_2])
168+
.await
169+
.unwrap()
170+
.ok()
171+
.unwrap();
172+
let (join_group_2, join_expr_2) = memo
173+
.add_group(join2, &[scan_id_2, scan_id_1])
174+
.await
175+
.unwrap()
176+
.ok()
177+
.unwrap();
178+
assert_ne!(join_group_1, join_group_2);
179+
180+
// Assume that some rule was applied to `join1`, and it outputs something like `join_unknown`.
181+
// The memo table will tell us that `join_unknown == join2`.
182+
// Take note here that `join_unknown` is a clone of `join2`, not `join1`.
183+
let (existing_group, not_actually_new_expr_id) = memo
184+
.add_logical_expression_to_group(join_group_1, join_unknown, &[scan_id_2, scan_id_1])
185+
.await
186+
.unwrap()
187+
.err()
188+
.unwrap();
189+
assert_eq!(existing_group, join_group_2);
190+
assert_eq!(not_actually_new_expr_id, join_expr_2);
191+
192+
// The above tells the application that the expression already exists in the memo, specifically
193+
// under `existing_group`. Thus, we should link these two groups together.
194+
// Here, we arbitrarily choose to link group 1 into group 2.
195+
memo.update_group_parent(join_group_1, join_group_2).await.unwrap();
196+
197+
let test_root_1 = memo.get_root_group(join_group_1).await.unwrap();
198+
let test_root_2 = memo.get_root_group(join_group_2).await.unwrap();
199+
assert_eq!(test_root_1, test_root_2);
200+
201+
// TODO(Connor)
202+
//
203+
// We now need to find all logical expressions that had group 1 (or whatever the root group of
204+
// the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a
205+
// new fingerprint for each one that uses group 2 as a child instead.
206+
//
207+
// In order to do this, we need to iterate through every group in group 1's set.
208+
209+
memo.cleanup().await;
210+
}

optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*};
7474
pub enum CascadesGroup {
7575
Table,
7676
Id,
77+
Status,
7778
Winner,
7879
Cost,
79-
IsOptimized,
8080
ParentId,
8181
}
8282

@@ -92,16 +92,16 @@ impl MigrationTrait for Migration {
9292
.table(CascadesGroup::Table)
9393
.if_not_exists()
9494
.col(pk_auto(CascadesGroup::Id))
95+
.col(tiny_integer(CascadesGroup::Status))
9596
.col(integer_null(CascadesGroup::Winner))
96-
.col(big_unsigned_null(CascadesGroup::Cost))
97+
.col(big_integer_null(CascadesGroup::Cost))
9798
.foreign_key(
9899
ForeignKey::create()
99100
.from(CascadesGroup::Table, CascadesGroup::Winner)
100101
.to(PhysicalExpression::Table, PhysicalExpression::Id)
101102
.on_delete(ForeignKeyAction::SetNull)
102103
.on_update(ForeignKeyAction::Cascade),
103104
)
104-
.col(boolean(CascadesGroup::IsOptimized))
105105
.col(integer_null(CascadesGroup::ParentId))
106106
.foreign_key(
107107
ForeignKey::create()

0 commit comments

Comments
 (0)