Skip to content

Commit 4d58221

Browse files
committed
implement transaction on all memo table operations
1 parent 2c015b4 commit 4d58221

File tree

3 files changed

+187
-123
lines changed

3 files changed

+187
-123
lines changed

optd-mvp/src/memo/persistent/implementation.rs

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66
77
#![allow(dead_code)]
88

9-
use super::PersistentMemo;
9+
use super::{PersistentMemo, PersistentMemoTransaction};
1010
use crate::{
1111
entities::*,
1212
expression::{LogicalExpression, PhysicalExpression},
1313
memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId},
1414
OptimizerResult, DATABASE_URL,
1515
};
1616
use sea_orm::{
17-
entity::prelude::*,
18-
entity::{IntoActiveModel, NotSet, Set},
19-
Database,
17+
entity::{prelude::*, IntoActiveModel, NotSet, Set},
18+
Database, DatabaseTransaction, TransactionTrait,
2019
};
2120
use std::{collections::HashSet, marker::PhantomData};
2221

@@ -39,6 +38,15 @@ where
3938
}
4039
}
4140

41+
/// Starts a new database transaction.
42+
///
43+
/// # Errors
44+
///
45+
/// Returns a [`DbErr`] if unable to create a new transaction.
46+
pub async fn begin(&self) -> OptimizerResult<PersistentMemoTransaction<L, P>> {
47+
Ok(PersistentMemoTransaction::new(self.db.begin().await?).await)
48+
}
49+
4250
/// Deletes all objects in the backing database.
4351
///
4452
/// Since there is no asynchronous drop yet in Rust, in order to drop all objects in the
@@ -71,6 +79,39 @@ where
7179
physical_children
7280
};
7381
}
82+
}
83+
84+
impl<L, P> PersistentMemoTransaction<L, P>
85+
where
86+
L: LogicalExpression,
87+
P: PhysicalExpression,
88+
{
89+
/// Creates a new transaction object.
90+
pub async fn new(txn: DatabaseTransaction) -> Self {
91+
Self {
92+
txn,
93+
_phantom_logical: PhantomData,
94+
_phantom_physical: PhantomData,
95+
}
96+
}
97+
98+
/// Commits the transaction.
99+
///
100+
/// # Errors
101+
///
102+
/// Returns a [`DbErr`] if unable to commit the transaction.
103+
pub async fn commit(self) -> OptimizerResult<()> {
104+
Ok(self.txn.commit().await?)
105+
}
106+
107+
/// Rolls back the transaction.
108+
///
109+
/// # Errors
110+
///
111+
/// Returns a [`DbErr`] if unable to roll back the transaction.
112+
pub async fn rollback(self) -> OptimizerResult<()> {
113+
Ok(self.txn.rollback().await?)
114+
}
74115

75116
/// Retrieves a [`group::Model`] given its ID.
76117
///
@@ -81,7 +122,7 @@ where
81122
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
82123
pub async fn get_group(&self, group_id: GroupId) -> OptimizerResult<group::Model> {
83124
Ok(group::Entity::find_by_id(group_id.0)
84-
.one(&self.db)
125+
.one(&self.txn)
85126
.await?
86127
.ok_or(MemoError::UnknownGroup(group_id))?)
87128
}
@@ -117,7 +158,7 @@ where
117158

118159
// Update the group to point to the new parent.
119160
active_group.parent_id = Set(Some(root_id.0));
120-
active_group.update(&self.db).await?;
161+
active_group.update(&self.txn).await?;
121162

122163
Ok(GroupId(root_id.0))
123164
}
@@ -180,7 +221,7 @@ where
180221
) -> OptimizerResult<(GroupId, P)> {
181222
// Lookup the entity in the database via the unique expression ID.
182223
let model = physical_expression::Entity::find_by_id(physical_expression_id.0)
183-
.one(&self.db)
224+
.one(&self.txn)
184225
.await?
185226
.ok_or(MemoError::UnknownPhysicalExpression(physical_expression_id))?;
186227

@@ -202,7 +243,7 @@ where
202243
) -> OptimizerResult<(GroupId, L)> {
203244
// Lookup the entity in the database via the unique expression ID.
204245
let model = logical_expression::Entity::find_by_id(logical_expression_id.0)
205-
.one(&self.db)
246+
.one(&self.txn)
206247
.await?
207248
.ok_or(MemoError::UnknownLogicalExpression(logical_expression_id))?;
208249

@@ -230,7 +271,7 @@ where
230271
// Search for expressions that have the given parent group ID.
231272
let children = logical_expression::Entity::find()
232273
.filter(logical_expression::Column::GroupId.eq(group_id.0))
233-
.all(&self.db)
274+
.all(&self.txn)
234275
.await?
235276
.into_iter()
236277
.map(|m| LogicalExpressionId(m.id))
@@ -257,7 +298,7 @@ where
257298
// Search for expressions that have the given parent group ID.
258299
let children = physical_expression::Entity::find()
259300
.filter(physical_expression::Column::GroupId.eq(group_id.0))
260-
.all(&self.db)
301+
.all(&self.txn)
261302
.await?
262303
.into_iter()
263304
.map(|m| PhysicalExpressionId(m.id))
@@ -273,7 +314,7 @@ where
273314
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. Can also return a
274315
/// [`DbErr`] if the update fails.
275316
pub async fn update_group_status(
276-
&self,
317+
&mut self,
277318
group_id: GroupId,
278319
status: GroupStatus,
279320
) -> OptimizerResult<GroupStatus> {
@@ -283,7 +324,7 @@ where
283324
// Update the group's status.
284325
let old_status = group.status;
285326
group.status = Set(status as u8 as i8);
286-
group.update(&self.db).await?;
327+
group.update(&self.txn).await?;
287328

288329
let old_status = match old_status.unwrap() {
289330
0 => GroupStatus::InProgress,
@@ -306,7 +347,7 @@ where
306347
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. Can also return a
307348
/// [`DbErr`] if the update fails.
308349
pub async fn update_group_winner(
309-
&self,
350+
&mut self,
310351
group_id: GroupId,
311352
physical_expression_id: PhysicalExpressionId,
312353
) -> OptimizerResult<Option<PhysicalExpressionId>> {
@@ -316,7 +357,7 @@ where
316357
// Update the group to point to the new winner.
317358
let old_id = group.winner;
318359
group.winner = Set(Some(physical_expression_id.0));
319-
group.update(&self.db).await?;
360+
group.update(&self.txn).await?;
320361

321362
// Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
322363
let old_id = old_id.unwrap().map(PhysicalExpressionId);
@@ -345,7 +386,7 @@ where
345386
/// is used for notifying the caller if the expression that they attempted to insert was a
346387
/// duplicate expression or not.
347388
pub async fn add_logical_expression_to_group(
348-
&self,
389+
&mut self,
349390
group_id: GroupId,
350391
logical_expression: L,
351392
children: &[GroupId],
@@ -366,7 +407,7 @@ where
366407
let mut active_model = model.into_active_model();
367408
active_model.group_id = Set(group_id.0);
368409
active_model.id = NotSet;
369-
let new_model = active_model.insert(&self.db).await?;
410+
let new_model = active_model.insert(&self.txn).await?;
370411

371412
let expr_id = new_model.id;
372413

@@ -378,7 +419,7 @@ where
378419
}
379420
}))
380421
.on_empty_do_nothing()
381-
.exec(&self.db)
422+
.exec(&self.txn)
382423
.await?;
383424

384425
// Finally, insert the fingerprint of the logical expression as well.
@@ -401,7 +442,7 @@ where
401442
hash: Set(hash),
402443
};
403444
fingerprint::Entity::insert(fingerprint)
404-
.exec(&self.db)
445+
.exec(&self.txn)
405446
.await?;
406447

407448
Ok(Ok(LogicalExpressionId(expr_id)))
@@ -420,7 +461,7 @@ where
420461
/// insertion of the new physical expression or any of its child junction entries are not able
421462
/// to be inserted.
422463
pub async fn add_physical_expression_to_group(
423-
&self,
464+
&mut self,
424465
group_id: GroupId,
425466
physical_expression: P,
426467
children: &[GroupId],
@@ -433,7 +474,7 @@ where
433474
let mut active_model = model.into_active_model();
434475
active_model.group_id = Set(group_id.0);
435476
active_model.id = NotSet;
436-
let new_model = active_model.insert(&self.db).await?;
477+
let new_model = active_model.insert(&self.txn).await?;
437478

438479
// Insert the child groups of the expression into the junction / children table.
439480
physical_children::Entity::insert_many(children.iter().copied().map(|child_id| {
@@ -443,7 +484,7 @@ where
443484
}
444485
}))
445486
.on_empty_do_nothing()
446-
.exec(&self.db)
487+
.exec(&self.txn)
447488
.await?;
448489

449490
Ok(PhysicalExpressionId(new_model.id))
@@ -490,7 +531,7 @@ where
490531
let potential_matches = fingerprint::Entity::find()
491532
.filter(fingerprint::Column::Hash.eq(fingerprint))
492533
.filter(fingerprint::Column::Kind.eq(kind))
493-
.all(&self.db)
534+
.all(&self.txn)
494535
.await?;
495536

496537
if potential_matches.is_empty() {
@@ -549,7 +590,7 @@ where
549590
/// is used for notifying the caller if the expression/group that they attempted to insert was a
550591
/// duplicate expression or not.
551592
pub async fn add_group(
552-
&self,
593+
&mut self,
553594
logical_expression: L,
554595
children: &[GroupId],
555596
) -> OptimizerResult<Result<(GroupId, LogicalExpressionId), (GroupId, LogicalExpressionId)>>
@@ -569,15 +610,15 @@ where
569610
};
570611

571612
// Create the new group.
572-
let group_res = group::Entity::insert(group).exec(&self.db).await?;
613+
let group_res = group::Entity::insert(group).exec(&self.txn).await?;
573614
let group_id = group_res.last_insert_id;
574615

575616
// Insert the input expression into the newly created group.
576617
let expression: logical_expression::Model = logical_expression.clone().into();
577618
let mut active_expression = expression.into_active_model();
578619
active_expression.group_id = Set(group_id);
579620
active_expression.id = NotSet;
580-
let new_expression = active_expression.insert(&self.db).await?;
621+
let new_expression = active_expression.insert(&self.txn).await?;
581622

582623
let group_id = new_expression.group_id;
583624
let expr_id = new_expression.id;
@@ -590,7 +631,7 @@ where
590631
}
591632
}))
592633
.on_empty_do_nothing()
593-
.exec(&self.db)
634+
.exec(&self.txn)
594635
.await?;
595636

596637
// Finally, insert the fingerprint of the logical expression as well.
@@ -613,7 +654,7 @@ where
613654
hash: Set(hash),
614655
};
615656
fingerprint::Entity::insert(fingerprint)
616-
.exec(&self.db)
657+
.exec(&self.txn)
617658
.await?;
618659

619660
Ok(Ok((GroupId(group_id), LogicalExpressionId(expr_id))))
@@ -631,7 +672,7 @@ where
631672
///
632673
/// TODO
633674
pub async fn merge_groups(
634-
&self,
675+
&mut self,
635676
left_group_id: GroupId,
636677
right_group_id: GroupId,
637678
) -> OptimizerResult<GroupId> {
@@ -665,7 +706,7 @@ where
665706
.load_many_to_many(
666707
logical_expression::Entity,
667708
logical_children::Entity,
668-
&self.db,
709+
&self.txn,
669710
)
670711
.await?;
671712

@@ -697,7 +738,7 @@ where
697738
hash: Set(hash),
698739
};
699740
fingerprint::Entity::insert(fingerprint)
700-
.exec(&self.db)
741+
.exec(&self.txn)
701742
.await?;
702743
}
703744

@@ -708,8 +749,8 @@ where
708749
active_left_root.next_id = Set(Some(right_next));
709750
active_right_root.next_id = Set(Some(left_next));
710751

711-
active_left_root.update(&self.db).await?;
712-
active_right_root.update(&self.db).await?;
752+
active_left_root.update(&self.txn).await?;
753+
active_right_root.update(&self.txn).await?;
713754

714755
Ok(right_root_id)
715756
}

optd-mvp/src/memo/persistent/mod.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! This module contains the definition and implementation of the [`PersistentMemo`] type, which
22
//! implements the `Memo` trait and supports memo table operations necessary for query optimization.
33
4-
use sea_orm::DatabaseConnection;
4+
use sea_orm::{DatabaseConnection, DatabaseTransaction};
55
use std::marker::PhantomData;
66

77
#[cfg(test)]
@@ -22,4 +22,16 @@ pub struct PersistentMemo<L, P> {
2222
_phantom_physical: PhantomData<P>,
2323
}
2424

25+
/// TODO docs.
26+
pub struct PersistentMemoTransaction<L, P> {
27+
/// A database transaction over the [`PersistentMemo`] table.
28+
txn: DatabaseTransaction,
29+
30+
/// Generic marker for a generic logical expression.
31+
_phantom_logical: PhantomData<L>,
32+
33+
/// Generic marker for a generic physical expression.
34+
_phantom_physical: PhantomData<P>,
35+
}
36+
2537
pub mod implementation;

0 commit comments

Comments
 (0)