77//!
88//! This module is responsible for writing new records to MAS' database.
99
10- use std:: { fmt:: Display , net:: IpAddr } ;
10+ use std:: {
11+ fmt:: Display ,
12+ net:: IpAddr ,
13+ sync:: {
14+ atomic:: { AtomicU32 , Ordering } ,
15+ Arc ,
16+ } ,
17+ } ;
1118
1219use chrono:: { DateTime , Utc } ;
1320use futures_util:: { future:: BoxFuture , FutureExt , TryStreamExt } ;
@@ -44,6 +51,9 @@ pub enum Error {
4451 #[ error( "inconsistent database: {0}" ) ]
4552 Inconsistent ( String ) ,
4653
54+ #[ error( "bug in syn2mas: write buffers not finished" ) ]
55+ WriteBuffersNotFinished ,
56+
4757 #[ error( "{0}" ) ]
4858 Multiple ( MultipleErrors ) ,
4959}
@@ -185,12 +195,52 @@ impl WriterConnectionPool {
185195 }
186196}
187197
198+ /// Small utility to make sure `finish()` is called on all write buffers
199+ /// before committing to the database.
200+ #[ derive( Default ) ]
201+ struct FinishChecker {
202+ counter : Arc < AtomicU32 > ,
203+ }
204+
205+ struct FinishCheckerHandle {
206+ counter : Arc < AtomicU32 > ,
207+ }
208+
209+ impl FinishChecker {
210+ /// Acquire a new handle, for a task that should declare when it has
211+ /// finished.
212+ pub fn handle ( & self ) -> FinishCheckerHandle {
213+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
214+ FinishCheckerHandle {
215+ counter : Arc :: clone ( & self . counter ) ,
216+ }
217+ }
218+
219+ /// Check that all handles have been declared as finished.
220+ pub fn check_all_finished ( self ) -> Result < ( ) , Error > {
221+ if self . counter . load ( Ordering :: SeqCst ) == 0 {
222+ Ok ( ( ) )
223+ } else {
224+ Err ( Error :: WriteBuffersNotFinished )
225+ }
226+ }
227+ }
228+
229+ impl FinishCheckerHandle {
230+ /// Declare that the task this handle represents has been finished.
231+ pub fn declare_finished ( self ) {
232+ self . counter . fetch_sub ( 1 , Ordering :: SeqCst ) ;
233+ }
234+ }
235+
188236pub struct MasWriter < ' c > {
189237 conn : LockedMasDatabase < ' c > ,
190238 writer_pool : WriterConnectionPool ,
191239
192240 indices_to_restore : Vec < IndexDescription > ,
193241 constraints_to_restore : Vec < ConstraintDescription > ,
242+
243+ write_buffer_finish_checker : FinishChecker ,
194244}
195245
196246pub struct MasNewUser {
@@ -449,6 +499,7 @@ impl<'conn> MasWriter<'conn> {
449499 writer_pool : WriterConnectionPool :: new ( writer_connections) ,
450500 indices_to_restore,
451501 constraints_to_restore,
502+ write_buffer_finish_checker : FinishChecker :: default ( ) ,
452503 } )
453504 }
454505
@@ -515,6 +566,8 @@ impl<'conn> MasWriter<'conn> {
515566 /// - If the database connection experiences an error.
516567 #[ tracing:: instrument( skip_all) ]
517568 pub async fn finish ( mut self ) -> Result < ( ) , Error > {
569+ self . write_buffer_finish_checker . check_all_finished ( ) ?;
570+
518571 // Commit all writer transactions to the database.
519572 self . writer_pool
520573 . finish ( )
@@ -1027,28 +1080,24 @@ type WriteBufferFlusher<'conn, T> =
10271080
10281081/// A buffer for writing rows to the MAS database.
10291082/// Generic over the type of rows.
1030- ///
1031- /// # Panics
1032- ///
1033- /// Panics if dropped before `finish()` has been called.
10341083pub struct MasWriteBuffer < ' conn , T > {
10351084 rows : Vec < T > ,
10361085 flusher : WriteBufferFlusher < ' conn , T > ,
1037- finished : bool ,
1086+ finish_checker_handle : FinishCheckerHandle ,
10381087}
10391088
10401089impl < ' conn , T > MasWriteBuffer < ' conn , T > {
1041- pub fn new ( flusher : WriteBufferFlusher < ' conn , T > ) -> Self {
1090+ pub fn new ( writer : & MasWriter , flusher : WriteBufferFlusher < ' conn , T > ) -> Self {
10421091 MasWriteBuffer {
10431092 rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
10441093 flusher,
1045- finished : false ,
1094+ finish_checker_handle : writer . write_buffer_finish_checker . handle ( ) ,
10461095 }
10471096 }
10481097
10491098 pub async fn finish ( mut self , writer : & mut MasWriter < ' conn > ) -> Result < ( ) , Error > {
1050- self . finished = true ;
10511099 self . flush ( writer) . await ?;
1100+ self . finish_checker_handle . declare_finished ( ) ;
10521101 Ok ( ( ) )
10531102 }
10541103
@@ -1071,12 +1120,6 @@ impl<'conn, T> MasWriteBuffer<'conn, T> {
10711120 }
10721121}
10731122
1074- impl < T > Drop for MasWriteBuffer < ' _ , T > {
1075- fn drop ( & mut self ) {
1076- assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
1077- }
1078- }
1079-
10801123#[ cfg( test) ]
10811124mod test {
10821125 use std:: collections:: { BTreeMap , BTreeSet } ;
0 commit comments