7
7
//!
8
8
//! This module is responsible for writing new records to MAS' database.
9
9
10
- use std:: { fmt:: Display , net:: IpAddr } ;
10
+ use std:: {
11
+ fmt:: Display ,
12
+ net:: IpAddr ,
13
+ sync:: {
14
+ Arc ,
15
+ atomic:: { AtomicU32 , Ordering } ,
16
+ } ,
17
+ } ;
11
18
12
19
use chrono:: { DateTime , Utc } ;
13
20
use futures_util:: { FutureExt , TryStreamExt , future:: BoxFuture } ;
@@ -44,6 +51,9 @@ pub enum Error {
44
51
#[ error( "inconsistent database: {0}" ) ]
45
52
Inconsistent ( String ) ,
46
53
54
+ #[ error( "bug in syn2mas: write buffers not finished" ) ]
55
+ WriteBuffersNotFinished ,
56
+
47
57
#[ error( "{0}" ) ]
48
58
Multiple ( MultipleErrors ) ,
49
59
}
@@ -188,12 +198,52 @@ impl WriterConnectionPool {
188
198
}
189
199
}
190
200
201
+ /// Small utility to make sure `finish()` is called on all write buffers
202
+ /// before committing to the database.
203
+ #[ derive( Default ) ]
204
+ struct FinishChecker {
205
+ counter : Arc < AtomicU32 > ,
206
+ }
207
+
208
+ struct FinishCheckerHandle {
209
+ counter : Arc < AtomicU32 > ,
210
+ }
211
+
212
+ impl FinishChecker {
213
+ /// Acquire a new handle, for a task that should declare when it has
214
+ /// finished.
215
+ pub fn handle ( & self ) -> FinishCheckerHandle {
216
+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
217
+ FinishCheckerHandle {
218
+ counter : Arc :: clone ( & self . counter ) ,
219
+ }
220
+ }
221
+
222
+ /// Check that all handles have been declared as finished.
223
+ pub fn check_all_finished ( self ) -> Result < ( ) , Error > {
224
+ if self . counter . load ( Ordering :: SeqCst ) == 0 {
225
+ Ok ( ( ) )
226
+ } else {
227
+ Err ( Error :: WriteBuffersNotFinished )
228
+ }
229
+ }
230
+ }
231
+
232
+ impl FinishCheckerHandle {
233
+ /// Declare that the task this handle represents has been finished.
234
+ pub fn declare_finished ( self ) {
235
+ self . counter . fetch_sub ( 1 , Ordering :: SeqCst ) ;
236
+ }
237
+ }
238
+
191
239
pub struct MasWriter {
192
240
conn : LockedMasDatabase ,
193
241
writer_pool : WriterConnectionPool ,
194
242
195
243
indices_to_restore : Vec < IndexDescription > ,
196
244
constraints_to_restore : Vec < ConstraintDescription > ,
245
+
246
+ write_buffer_finish_checker : FinishChecker ,
197
247
}
198
248
199
249
pub struct MasNewUser {
@@ -453,6 +503,7 @@ impl MasWriter {
453
503
writer_pool : WriterConnectionPool :: new ( writer_connections) ,
454
504
indices_to_restore,
455
505
constraints_to_restore,
506
+ write_buffer_finish_checker : FinishChecker :: default ( ) ,
456
507
} )
457
508
}
458
509
@@ -520,6 +571,8 @@ impl MasWriter {
520
571
/// - If the database connection experiences an error.
521
572
#[ tracing:: instrument( skip_all) ]
522
573
pub async fn finish ( mut self ) -> Result < PgConnection , Error > {
574
+ self . write_buffer_finish_checker . check_all_finished ( ) ?;
575
+
523
576
// Commit all writer transactions to the database.
524
577
self . writer_pool
525
578
. finish ( )
@@ -1033,28 +1086,24 @@ type WriteBufferFlusher<T> =
1033
1086
1034
1087
/// A buffer for writing rows to the MAS database.
1035
1088
/// Generic over the type of rows.
1036
- ///
1037
- /// # Panics
1038
- ///
1039
- /// Panics if dropped before `finish()` has been called.
1040
1089
pub struct MasWriteBuffer < T > {
1041
1090
rows : Vec < T > ,
1042
1091
flusher : WriteBufferFlusher < T > ,
1043
- finished : bool ,
1092
+ finish_checker_handle : FinishCheckerHandle ,
1044
1093
}
1045
1094
1046
1095
impl < T > MasWriteBuffer < T > {
1047
- pub fn new ( flusher : WriteBufferFlusher < T > ) -> Self {
1096
+ pub fn new ( writer : & MasWriter , flusher : WriteBufferFlusher < T > ) -> Self {
1048
1097
MasWriteBuffer {
1049
1098
rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
1050
1099
flusher,
1051
- finished : false ,
1100
+ finish_checker_handle : writer . write_buffer_finish_checker . handle ( ) ,
1052
1101
}
1053
1102
}
1054
1103
1055
1104
pub async fn finish ( mut self , writer : & mut MasWriter ) -> Result < ( ) , Error > {
1056
- self . finished = true ;
1057
1105
self . flush ( writer) . await ?;
1106
+ self . finish_checker_handle . declare_finished ( ) ;
1058
1107
Ok ( ( ) )
1059
1108
}
1060
1109
@@ -1077,12 +1126,6 @@ impl<T> MasWriteBuffer<T> {
1077
1126
}
1078
1127
}
1079
1128
1080
- impl < T > Drop for MasWriteBuffer < T > {
1081
- fn drop ( & mut self ) {
1082
- assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
1083
- }
1084
- }
1085
-
1086
1129
#[ cfg( test) ]
1087
1130
mod test {
1088
1131
use std:: collections:: { BTreeMap , BTreeSet } ;
0 commit comments