66
77#![ allow( dead_code) ]
88
9- use super :: PersistentMemo ;
9+ use super :: { PersistentMemo , PersistentMemoTransaction } ;
1010use crate :: {
1111 entities:: * ,
1212 expression:: { LogicalExpression , PhysicalExpression } ,
1313 memo:: { GroupId , GroupStatus , LogicalExpressionId , MemoError , PhysicalExpressionId } ,
1414 OptimizerResult , DATABASE_URL ,
1515} ;
1616use sea_orm:: {
17- entity:: prelude:: * ,
18- entity:: { IntoActiveModel , NotSet , Set } ,
19- Database ,
17+ entity:: { prelude:: * , IntoActiveModel , NotSet , Set } ,
18+ Database , DatabaseTransaction , TransactionTrait ,
2019} ;
2120use std:: { collections:: HashSet , marker:: PhantomData } ;
2221
3938 }
4039 }
4140
41+ /// Starts a new database transaction.
42+ ///
43+ /// # Errors
44+ ///
45+ /// Returns a [`DbErr`] if unable to create a new transaction.
46+ pub async fn begin ( & self ) -> OptimizerResult < PersistentMemoTransaction < L , P > > {
47+ Ok ( PersistentMemoTransaction :: new ( self . db . begin ( ) . await ?) . await )
48+ }
49+
4250 /// Deletes all objects in the backing database.
4351 ///
4452 /// Since there is no asynchronous drop yet in Rust, in order to drop all objects in the
7179 physical_children
7280 } ;
7381 }
82+ }
83+
84+ impl < L , P > PersistentMemoTransaction < L , P >
85+ where
86+ L : LogicalExpression ,
87+ P : PhysicalExpression ,
88+ {
89+ /// Creates a new transaction object.
90+ pub async fn new ( txn : DatabaseTransaction ) -> Self {
91+ Self {
92+ txn,
93+ _phantom_logical : PhantomData ,
94+ _phantom_physical : PhantomData ,
95+ }
96+ }
97+
98+ /// Commits the transaction.
99+ ///
100+ /// # Errors
101+ ///
102+ /// Returns a [`DbErr`] if unable to commit the transaction.
103+ pub async fn commit ( self ) -> OptimizerResult < ( ) > {
104+ Ok ( self . txn . commit ( ) . await ?)
105+ }
106+
107+ /// Rolls back the transaction.
108+ ///
109+ /// # Errors
110+ ///
111+ /// Returns a [`DbErr`] if unable to roll back the transaction.
112+ pub async fn rollback ( self ) -> OptimizerResult < ( ) > {
113+ Ok ( self . txn . rollback ( ) . await ?)
114+ }
74115
75116 /// Retrieves a [`group::Model`] given its ID.
76117 ///
81122 /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
82123 pub async fn get_group ( & self , group_id : GroupId ) -> OptimizerResult < group:: Model > {
83124 Ok ( group:: Entity :: find_by_id ( group_id. 0 )
84- . one ( & self . db )
125+ . one ( & self . txn )
85126 . await ?
86127 . ok_or ( MemoError :: UnknownGroup ( group_id) ) ?)
87128 }
@@ -117,7 +158,7 @@ where
117158
118159 // Update the group to point to the new parent.
119160 active_group. parent_id = Set ( Some ( root_id. 0 ) ) ;
120- active_group. update ( & self . db ) . await ?;
161+ active_group. update ( & self . txn ) . await ?;
121162
122163 Ok ( GroupId ( root_id. 0 ) )
123164 }
@@ -180,7 +221,7 @@ where
180221 ) -> OptimizerResult < ( GroupId , P ) > {
181222 // Lookup the entity in the database via the unique expression ID.
182223 let model = physical_expression:: Entity :: find_by_id ( physical_expression_id. 0 )
183- . one ( & self . db )
224+ . one ( & self . txn )
184225 . await ?
185226 . ok_or ( MemoError :: UnknownPhysicalExpression ( physical_expression_id) ) ?;
186227
@@ -202,7 +243,7 @@ where
202243 ) -> OptimizerResult < ( GroupId , L ) > {
203244 // Lookup the entity in the database via the unique expression ID.
204245 let model = logical_expression:: Entity :: find_by_id ( logical_expression_id. 0 )
205- . one ( & self . db )
246+ . one ( & self . txn )
206247 . await ?
207248 . ok_or ( MemoError :: UnknownLogicalExpression ( logical_expression_id) ) ?;
208249
@@ -230,7 +271,7 @@ where
230271 // Search for expressions that have the given parent group ID.
231272 let children = logical_expression:: Entity :: find ( )
232273 . filter ( logical_expression:: Column :: GroupId . eq ( group_id. 0 ) )
233- . all ( & self . db )
274+ . all ( & self . txn )
234275 . await ?
235276 . into_iter ( )
236277 . map ( |m| LogicalExpressionId ( m. id ) )
@@ -257,7 +298,7 @@ where
257298 // Search for expressions that have the given parent group ID.
258299 let children = physical_expression:: Entity :: find ( )
259300 . filter ( physical_expression:: Column :: GroupId . eq ( group_id. 0 ) )
260- . all ( & self . db )
301+ . all ( & self . txn )
261302 . await ?
262303 . into_iter ( )
263304 . map ( |m| PhysicalExpressionId ( m. id ) )
@@ -273,7 +314,7 @@ where
273314 /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. Can also return a
274315 /// [`DbErr`] if the update fails.
275316 pub async fn update_group_status (
276- & self ,
317+ & mut self ,
277318 group_id : GroupId ,
278319 status : GroupStatus ,
279320 ) -> OptimizerResult < GroupStatus > {
@@ -283,7 +324,7 @@ where
283324 // Update the group's status.
284325 let old_status = group. status ;
285326 group. status = Set ( status as u8 as i8 ) ;
286- group. update ( & self . db ) . await ?;
327+ group. update ( & self . txn ) . await ?;
287328
288329 let old_status = match old_status. unwrap ( ) {
289330 0 => GroupStatus :: InProgress ,
@@ -306,7 +347,7 @@ where
306347 /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. Can also return a
307348 /// [`DbErr`] if the update fails.
308349 pub async fn update_group_winner (
309- & self ,
350+ & mut self ,
310351 group_id : GroupId ,
311352 physical_expression_id : PhysicalExpressionId ,
312353 ) -> OptimizerResult < Option < PhysicalExpressionId > > {
@@ -316,7 +357,7 @@ where
316357 // Update the group to point to the new winner.
317358 let old_id = group. winner ;
318359 group. winner = Set ( Some ( physical_expression_id. 0 ) ) ;
319- group. update ( & self . db ) . await ?;
360+ group. update ( & self . txn ) . await ?;
320361
321362 // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`.
322363 let old_id = old_id. unwrap ( ) . map ( PhysicalExpressionId ) ;
@@ -345,7 +386,7 @@ where
345386 /// is used for notifying the caller if the expression that they attempted to insert was a
346387 /// duplicate expression or not.
347388 pub async fn add_logical_expression_to_group (
348- & self ,
389+ & mut self ,
349390 group_id : GroupId ,
350391 logical_expression : L ,
351392 children : & [ GroupId ] ,
@@ -366,7 +407,7 @@ where
366407 let mut active_model = model. into_active_model ( ) ;
367408 active_model. group_id = Set ( group_id. 0 ) ;
368409 active_model. id = NotSet ;
369- let new_model = active_model. insert ( & self . db ) . await ?;
410+ let new_model = active_model. insert ( & self . txn ) . await ?;
370411
371412 let expr_id = new_model. id ;
372413
@@ -378,7 +419,7 @@ where
378419 }
379420 } ) )
380421 . on_empty_do_nothing ( )
381- . exec ( & self . db )
422+ . exec ( & self . txn )
382423 . await ?;
383424
384425 // Finally, insert the fingerprint of the logical expression as well.
@@ -401,7 +442,7 @@ where
401442 hash : Set ( hash) ,
402443 } ;
403444 fingerprint:: Entity :: insert ( fingerprint)
404- . exec ( & self . db )
445+ . exec ( & self . txn )
405446 . await ?;
406447
407448 Ok ( Ok ( LogicalExpressionId ( expr_id) ) )
@@ -420,7 +461,7 @@ where
420461 /// insertion of the new physical expression or any of its child junction entries are not able
421462 /// to be inserted.
422463 pub async fn add_physical_expression_to_group (
423- & self ,
464+ & mut self ,
424465 group_id : GroupId ,
425466 physical_expression : P ,
426467 children : & [ GroupId ] ,
@@ -433,7 +474,7 @@ where
433474 let mut active_model = model. into_active_model ( ) ;
434475 active_model. group_id = Set ( group_id. 0 ) ;
435476 active_model. id = NotSet ;
436- let new_model = active_model. insert ( & self . db ) . await ?;
477+ let new_model = active_model. insert ( & self . txn ) . await ?;
437478
438479 // Insert the child groups of the expression into the junction / children table.
439480 physical_children:: Entity :: insert_many ( children. iter ( ) . copied ( ) . map ( |child_id| {
@@ -443,7 +484,7 @@ where
443484 }
444485 } ) )
445486 . on_empty_do_nothing ( )
446- . exec ( & self . db )
487+ . exec ( & self . txn )
447488 . await ?;
448489
449490 Ok ( PhysicalExpressionId ( new_model. id ) )
@@ -490,7 +531,7 @@ where
490531 let potential_matches = fingerprint:: Entity :: find ( )
491532 . filter ( fingerprint:: Column :: Hash . eq ( fingerprint) )
492533 . filter ( fingerprint:: Column :: Kind . eq ( kind) )
493- . all ( & self . db )
534+ . all ( & self . txn )
494535 . await ?;
495536
496537 if potential_matches. is_empty ( ) {
@@ -549,7 +590,7 @@ where
549590 /// is used for notifying the caller if the expression/group that they attempted to insert was a
550591 /// duplicate expression or not.
551592 pub async fn add_group (
552- & self ,
593+ & mut self ,
553594 logical_expression : L ,
554595 children : & [ GroupId ] ,
555596 ) -> OptimizerResult < Result < ( GroupId , LogicalExpressionId ) , ( GroupId , LogicalExpressionId ) > >
@@ -569,15 +610,15 @@ where
569610 } ;
570611
571612 // Create the new group.
572- let group_res = group:: Entity :: insert ( group) . exec ( & self . db ) . await ?;
613+ let group_res = group:: Entity :: insert ( group) . exec ( & self . txn ) . await ?;
573614 let group_id = group_res. last_insert_id ;
574615
575616 // Insert the input expression into the newly created group.
576617 let expression: logical_expression:: Model = logical_expression. clone ( ) . into ( ) ;
577618 let mut active_expression = expression. into_active_model ( ) ;
578619 active_expression. group_id = Set ( group_id) ;
579620 active_expression. id = NotSet ;
580- let new_expression = active_expression. insert ( & self . db ) . await ?;
621+ let new_expression = active_expression. insert ( & self . txn ) . await ?;
581622
582623 let group_id = new_expression. group_id ;
583624 let expr_id = new_expression. id ;
@@ -590,7 +631,7 @@ where
590631 }
591632 } ) )
592633 . on_empty_do_nothing ( )
593- . exec ( & self . db )
634+ . exec ( & self . txn )
594635 . await ?;
595636
596637 // Finally, insert the fingerprint of the logical expression as well.
@@ -613,7 +654,7 @@ where
613654 hash : Set ( hash) ,
614655 } ;
615656 fingerprint:: Entity :: insert ( fingerprint)
616- . exec ( & self . db )
657+ . exec ( & self . txn )
617658 . await ?;
618659
619660 Ok ( Ok ( ( GroupId ( group_id) , LogicalExpressionId ( expr_id) ) ) )
@@ -631,7 +672,7 @@ where
631672 ///
632673 /// TODO
633674 pub async fn merge_groups (
634- & self ,
675+ & mut self ,
635676 left_group_id : GroupId ,
636677 right_group_id : GroupId ,
637678 ) -> OptimizerResult < GroupId > {
@@ -665,7 +706,7 @@ where
665706 . load_many_to_many (
666707 logical_expression:: Entity ,
667708 logical_children:: Entity ,
668- & self . db ,
709+ & self . txn ,
669710 )
670711 . await ?;
671712
@@ -697,7 +738,7 @@ where
697738 hash : Set ( hash) ,
698739 } ;
699740 fingerprint:: Entity :: insert ( fingerprint)
700- . exec ( & self . db )
741+ . exec ( & self . txn )
701742 . await ?;
702743 }
703744
@@ -708,8 +749,8 @@ where
708749 active_left_root. next_id = Set ( Some ( right_next) ) ;
709750 active_right_root. next_id = Set ( Some ( left_next) ) ;
710751
711- active_left_root. update ( & self . db ) . await ?;
712- active_right_root. update ( & self . db ) . await ?;
752+ active_left_root. update ( & self . txn ) . await ?;
753+ active_right_root. update ( & self . txn ) . await ?;
713754
714755 Ok ( right_root_id)
715756 }
0 commit comments