7
7
//!
8
8
//! This module is responsible for writing new records to MAS' database.
9
9
10
- use std:: { fmt:: Display , net:: IpAddr } ;
10
+ use std:: {
11
+ fmt:: Display ,
12
+ net:: IpAddr ,
13
+ sync:: {
14
+ atomic:: { AtomicU32 , Ordering } ,
15
+ Arc ,
16
+ } ,
17
+ } ;
11
18
12
19
use chrono:: { DateTime , Utc } ;
13
20
use futures_util:: { future:: BoxFuture , FutureExt , TryStreamExt } ;
@@ -44,6 +51,9 @@ pub enum Error {
44
51
#[ error( "inconsistent database: {0}" ) ]
45
52
Inconsistent ( String ) ,
46
53
54
+ #[ error( "bug in syn2mas: write buffers not finished" ) ]
55
+ WriteBuffersNotFinished ,
56
+
47
57
#[ error( "{0}" ) ]
48
58
Multiple ( MultipleErrors ) ,
49
59
}
@@ -185,12 +195,52 @@ impl WriterConnectionPool {
185
195
}
186
196
}
187
197
198
+ /// Small utility to make sure `finish()` is called on all write buffers
199
+ /// before committing to the database.
200
+ #[ derive( Default ) ]
201
+ struct FinishChecker {
202
+ counter : Arc < AtomicU32 > ,
203
+ }
204
+
205
+ struct FinishCheckerHandle {
206
+ counter : Arc < AtomicU32 > ,
207
+ }
208
+
209
+ impl FinishChecker {
210
+ /// Acquire a new handle, for a task that should declare when it has
211
+ /// finished.
212
+ pub fn handle ( & self ) -> FinishCheckerHandle {
213
+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
214
+ FinishCheckerHandle {
215
+ counter : Arc :: clone ( & self . counter ) ,
216
+ }
217
+ }
218
+
219
+ /// Check that all handles have been declared as finished.
220
+ pub fn check_all_finished ( self ) -> Result < ( ) , Error > {
221
+ if self . counter . load ( Ordering :: SeqCst ) == 0 {
222
+ Ok ( ( ) )
223
+ } else {
224
+ Err ( Error :: WriteBuffersNotFinished )
225
+ }
226
+ }
227
+ }
228
+
229
+ impl FinishCheckerHandle {
230
+ /// Declare that the task this handle represents has been finished.
231
+ pub fn declare_finished ( self ) {
232
+ self . counter . fetch_sub ( 1 , Ordering :: SeqCst ) ;
233
+ }
234
+ }
235
+
188
236
pub struct MasWriter {
189
237
conn : LockedMasDatabase ,
190
238
writer_pool : WriterConnectionPool ,
191
239
192
240
indices_to_restore : Vec < IndexDescription > ,
193
241
constraints_to_restore : Vec < ConstraintDescription > ,
242
+
243
+ write_buffer_finish_checker : FinishChecker ,
194
244
}
195
245
196
246
pub struct MasNewUser {
@@ -450,6 +500,7 @@ impl MasWriter {
450
500
writer_pool : WriterConnectionPool :: new ( writer_connections) ,
451
501
indices_to_restore,
452
502
constraints_to_restore,
503
+ write_buffer_finish_checker : FinishChecker :: default ( ) ,
453
504
} )
454
505
}
455
506
@@ -517,6 +568,8 @@ impl MasWriter {
517
568
/// - If the database connection experiences an error.
518
569
#[ tracing:: instrument( skip_all) ]
519
570
pub async fn finish ( mut self ) -> Result < PgConnection , Error > {
571
+ self . write_buffer_finish_checker . check_all_finished ( ) ?;
572
+
520
573
// Commit all writer transactions to the database.
521
574
self . writer_pool
522
575
. finish ( )
@@ -1030,28 +1083,24 @@ type WriteBufferFlusher<T> =
1030
1083
1031
1084
/// A buffer for writing rows to the MAS database.
1032
1085
/// Generic over the type of rows.
1033
- ///
1034
- /// # Panics
1035
- ///
1036
- /// Panics if dropped before `finish()` has been called.
1037
1086
pub struct MasWriteBuffer < T > {
1038
1087
rows : Vec < T > ,
1039
1088
flusher : WriteBufferFlusher < T > ,
1040
- finished : bool ,
1089
+ finish_checker_handle : FinishCheckerHandle ,
1041
1090
}
1042
1091
1043
1092
impl < T > MasWriteBuffer < T > {
1044
- pub fn new ( flusher : WriteBufferFlusher < T > ) -> Self {
1093
+ pub fn new ( writer : & MasWriter , flusher : WriteBufferFlusher < T > ) -> Self {
1045
1094
MasWriteBuffer {
1046
1095
rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
1047
1096
flusher,
1048
- finished : false ,
1097
+ finish_checker_handle : writer . write_buffer_finish_checker . handle ( ) ,
1049
1098
}
1050
1099
}
1051
1100
1052
1101
pub async fn finish ( mut self , writer : & mut MasWriter ) -> Result < ( ) , Error > {
1053
- self . finished = true ;
1054
1102
self . flush ( writer) . await ?;
1103
+ self . finish_checker_handle . declare_finished ( ) ;
1055
1104
Ok ( ( ) )
1056
1105
}
1057
1106
@@ -1074,12 +1123,6 @@ impl<T> MasWriteBuffer<T> {
1074
1123
}
1075
1124
}
1076
1125
1077
- impl < T > Drop for MasWriteBuffer < T > {
1078
- fn drop ( & mut self ) {
1079
- assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
1080
- }
1081
- }
1082
-
1083
1126
#[ cfg( test) ]
1084
1127
mod test {
1085
1128
use std:: collections:: { BTreeMap , BTreeSet } ;
0 commit comments