77//!
88//! This module is responsible for writing new records to MAS' database.
99
10- use std:: { fmt:: Display , net:: IpAddr } ;
10+ use std:: {
11+ fmt:: Display ,
12+ net:: IpAddr ,
13+ sync:: {
14+ Arc ,
15+ atomic:: { AtomicU32 , Ordering } ,
16+ } ,
17+ } ;
1118
1219use chrono:: { DateTime , Utc } ;
1320use futures_util:: { FutureExt , TryStreamExt , future:: BoxFuture } ;
@@ -44,6 +51,9 @@ pub enum Error {
4451 #[ error( "inconsistent database: {0}" ) ]
4552 Inconsistent ( String ) ,
4653
54+ #[ error( "bug in syn2mas: write buffers not finished" ) ]
55+ WriteBuffersNotFinished ,
56+
4757 #[ error( "{0}" ) ]
4858 Multiple ( MultipleErrors ) ,
4959}
@@ -188,12 +198,52 @@ impl WriterConnectionPool {
188198 }
189199}
190200
201+ /// Small utility to make sure `finish()` is called on all write buffers
202+ /// before committing to the database.
203+ #[ derive( Default ) ]
204+ struct FinishChecker {
205+ counter : Arc < AtomicU32 > ,
206+ }
207+
208+ struct FinishCheckerHandle {
209+ counter : Arc < AtomicU32 > ,
210+ }
211+
212+ impl FinishChecker {
213+ /// Acquire a new handle, for a task that should declare when it has
214+ /// finished.
215+ pub fn handle ( & self ) -> FinishCheckerHandle {
216+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
217+ FinishCheckerHandle {
218+ counter : Arc :: clone ( & self . counter ) ,
219+ }
220+ }
221+
222+ /// Check that all handles have been declared as finished.
223+ pub fn check_all_finished ( self ) -> Result < ( ) , Error > {
224+ if self . counter . load ( Ordering :: SeqCst ) == 0 {
225+ Ok ( ( ) )
226+ } else {
227+ Err ( Error :: WriteBuffersNotFinished )
228+ }
229+ }
230+ }
231+
232+ impl FinishCheckerHandle {
233+ /// Declare that the task this handle represents has been finished.
234+ pub fn declare_finished ( self ) {
235+ self . counter . fetch_sub ( 1 , Ordering :: SeqCst ) ;
236+ }
237+ }
238+
191239pub struct MasWriter {
192240 conn : LockedMasDatabase ,
193241 writer_pool : WriterConnectionPool ,
194242
195243 indices_to_restore : Vec < IndexDescription > ,
196244 constraints_to_restore : Vec < ConstraintDescription > ,
245+
246+ write_buffer_finish_checker : FinishChecker ,
197247}
198248
199249pub struct MasNewUser {
@@ -453,6 +503,7 @@ impl MasWriter {
453503 writer_pool : WriterConnectionPool :: new ( writer_connections) ,
454504 indices_to_restore,
455505 constraints_to_restore,
506+ write_buffer_finish_checker : FinishChecker :: default ( ) ,
456507 } )
457508 }
458509
@@ -520,6 +571,8 @@ impl MasWriter {
520571 /// - If the database connection experiences an error.
521572 #[ tracing:: instrument( skip_all) ]
522573 pub async fn finish ( mut self ) -> Result < PgConnection , Error > {
574+ self . write_buffer_finish_checker . check_all_finished ( ) ?;
575+
523576 // Commit all writer transactions to the database.
524577 self . writer_pool
525578 . finish ( )
@@ -1033,28 +1086,24 @@ type WriteBufferFlusher<T> =
10331086
10341087/// A buffer for writing rows to the MAS database.
10351088/// Generic over the type of rows.
1036- ///
1037- /// # Panics
1038- ///
1039- /// Panics if dropped before `finish()` has been called.
10401089pub struct MasWriteBuffer < T > {
10411090 rows : Vec < T > ,
10421091 flusher : WriteBufferFlusher < T > ,
1043- finished : bool ,
1092+ finish_checker_handle : FinishCheckerHandle ,
10441093}
10451094
10461095impl < T > MasWriteBuffer < T > {
1047- pub fn new ( flusher : WriteBufferFlusher < T > ) -> Self {
1096+ pub fn new ( writer : & MasWriter , flusher : WriteBufferFlusher < T > ) -> Self {
10481097 MasWriteBuffer {
10491098 rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
10501099 flusher,
1051- finished : false ,
1100+ finish_checker_handle : writer . write_buffer_finish_checker . handle ( ) ,
10521101 }
10531102 }
10541103
10551104 pub async fn finish ( mut self , writer : & mut MasWriter ) -> Result < ( ) , Error > {
1056- self . finished = true ;
10571105 self . flush ( writer) . await ?;
1106+ self . finish_checker_handle . declare_finished ( ) ;
10581107 Ok ( ( ) )
10591108 }
10601109
@@ -1077,12 +1126,6 @@ impl<T> MasWriteBuffer<T> {
10771126 }
10781127}
10791128
1080- impl < T > Drop for MasWriteBuffer < T > {
1081- fn drop ( & mut self ) {
1082- assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
1083- }
1084- }
1085-
10861129#[ cfg( test) ]
10871130mod test {
10881131 use std:: collections:: { BTreeMap , BTreeSet } ;
0 commit comments