7
7
//!
8
8
//! This module is responsible for writing new records to MAS' database.
9
9
10
- use std:: { fmt:: Display , net:: IpAddr } ;
10
+ use std:: {
11
+ fmt:: Display ,
12
+ net:: IpAddr ,
13
+ sync:: {
14
+ Arc ,
15
+ atomic:: { AtomicU32 , Ordering } ,
16
+ } ,
17
+ } ;
11
18
12
19
use chrono:: { DateTime , Utc } ;
13
20
use futures_util:: { FutureExt , TryStreamExt , future:: BoxFuture } ;
14
21
use sqlx:: { Executor , PgConnection , query, query_as} ;
15
22
use thiserror:: Error ;
16
23
use thiserror_ext:: { Construct , ContextInto } ;
17
24
use tokio:: sync:: mpsc:: { self , Receiver , Sender } ;
18
- use tracing:: { Level , error, info, warn} ;
25
+ use tracing:: { Instrument , Level , error, info, warn} ;
19
26
use uuid:: { NonNilUuid , Uuid } ;
20
27
21
28
use self :: {
@@ -44,6 +51,9 @@ pub enum Error {
44
51
#[ error( "inconsistent database: {0}" ) ]
45
52
Inconsistent ( String ) ,
46
53
54
+ #[ error( "bug in syn2mas: write buffers not finished" ) ]
55
+ WriteBuffersNotFinished ,
56
+
47
57
#[ error( "{0}" ) ]
48
58
Multiple ( MultipleErrors ) ,
49
59
}
@@ -109,18 +119,21 @@ impl WriterConnectionPool {
109
119
match self . connection_rx . recv ( ) . await {
110
120
Some ( Ok ( mut connection) ) => {
111
121
let connection_tx = self . connection_tx . clone ( ) ;
112
- tokio:: task:: spawn ( async move {
113
- let to_return = match task ( & mut connection) . await {
114
- Ok ( ( ) ) => Ok ( connection) ,
115
- Err ( error) => {
116
- error ! ( "error in writer: {error}" ) ;
117
- Err ( error)
118
- }
119
- } ;
120
- // This should always succeed in sending unless we're already shutting
121
- // down for some other reason.
122
- let _: Result < _ , _ > = connection_tx. send ( to_return) . await ;
123
- } ) ;
122
+ tokio:: task:: spawn (
123
+ async move {
124
+ let to_return = match task ( & mut connection) . await {
125
+ Ok ( ( ) ) => Ok ( connection) ,
126
+ Err ( error) => {
127
+ error ! ( "error in writer: {error}" ) ;
128
+ Err ( error)
129
+ }
130
+ } ;
131
+ // This should always succeed in sending unless we're already shutting
132
+ // down for some other reason.
133
+ let _: Result < _ , _ > = connection_tx. send ( to_return) . await ;
134
+ }
135
+ . instrument ( tracing:: debug_span!( "spawn_with_connection" ) ) ,
136
+ ) ;
124
137
125
138
Ok ( ( ) )
126
139
}
@@ -188,12 +201,52 @@ impl WriterConnectionPool {
188
201
}
189
202
}
190
203
204
+ /// Small utility to make sure `finish()` is called on all write buffers
205
+ /// before committing to the database.
206
+ #[ derive( Default ) ]
207
+ struct FinishChecker {
208
+ counter : Arc < AtomicU32 > ,
209
+ }
210
+
211
+ struct FinishCheckerHandle {
212
+ counter : Arc < AtomicU32 > ,
213
+ }
214
+
215
+ impl FinishChecker {
216
+ /// Acquire a new handle, for a task that should declare when it has
217
+ /// finished.
218
+ pub fn handle ( & self ) -> FinishCheckerHandle {
219
+ self . counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
220
+ FinishCheckerHandle {
221
+ counter : Arc :: clone ( & self . counter ) ,
222
+ }
223
+ }
224
+
225
+ /// Check that all handles have been declared as finished.
226
+ pub fn check_all_finished ( self ) -> Result < ( ) , Error > {
227
+ if self . counter . load ( Ordering :: SeqCst ) == 0 {
228
+ Ok ( ( ) )
229
+ } else {
230
+ Err ( Error :: WriteBuffersNotFinished )
231
+ }
232
+ }
233
+ }
234
+
235
+ impl FinishCheckerHandle {
236
+ /// Declare that the task this handle represents has been finished.
237
+ pub fn declare_finished ( self ) {
238
+ self . counter . fetch_sub ( 1 , Ordering :: SeqCst ) ;
239
+ }
240
+ }
241
+
191
242
pub struct MasWriter {
192
243
conn : LockedMasDatabase ,
193
244
writer_pool : WriterConnectionPool ,
194
245
195
246
indices_to_restore : Vec < IndexDescription > ,
196
247
constraints_to_restore : Vec < ConstraintDescription > ,
248
+
249
+ write_buffer_finish_checker : FinishChecker ,
197
250
}
198
251
199
252
pub struct MasNewUser {
@@ -337,7 +390,7 @@ impl MasWriter {
337
390
///
338
391
/// - If the database connection experiences an error.
339
392
#[ allow( clippy:: missing_panics_doc) ] // not real
340
- #[ tracing:: instrument( skip_all) ]
393
+ #[ tracing:: instrument( name = "syn2mas.mas_writer.new" , skip_all) ]
341
394
pub async fn new (
342
395
mut conn : LockedMasDatabase ,
343
396
mut writer_connections : Vec < PgConnection > ,
@@ -454,6 +507,7 @@ impl MasWriter {
454
507
writer_pool : WriterConnectionPool :: new ( writer_connections) ,
455
508
indices_to_restore,
456
509
constraints_to_restore,
510
+ write_buffer_finish_checker : FinishChecker :: default ( ) ,
457
511
} )
458
512
}
459
513
@@ -521,6 +575,8 @@ impl MasWriter {
521
575
/// - If the database connection experiences an error.
522
576
#[ tracing:: instrument( skip_all) ]
523
577
pub async fn finish ( mut self ) -> Result < PgConnection , Error > {
578
+ self . write_buffer_finish_checker . check_all_finished ( ) ?;
579
+
524
580
// Commit all writer transactions to the database.
525
581
self . writer_pool
526
582
. finish ( )
@@ -1041,28 +1097,24 @@ type WriteBufferFlusher<T> =
1041
1097
1042
1098
/// A buffer for writing rows to the MAS database.
1043
1099
/// Generic over the type of rows.
1044
- ///
1045
- /// # Panics
1046
- ///
1047
- /// Panics if dropped before `finish()` has been called.
1048
1100
pub struct MasWriteBuffer < T > {
1049
1101
rows : Vec < T > ,
1050
1102
flusher : WriteBufferFlusher < T > ,
1051
- finished : bool ,
1103
+ finish_checker_handle : FinishCheckerHandle ,
1052
1104
}
1053
1105
1054
1106
impl < T > MasWriteBuffer < T > {
1055
- pub fn new ( flusher : WriteBufferFlusher < T > ) -> Self {
1107
+ pub fn new ( writer : & MasWriter , flusher : WriteBufferFlusher < T > ) -> Self {
1056
1108
MasWriteBuffer {
1057
1109
rows : Vec :: with_capacity ( WRITE_BUFFER_BATCH_SIZE ) ,
1058
1110
flusher,
1059
- finished : false ,
1111
+ finish_checker_handle : writer . write_buffer_finish_checker . handle ( ) ,
1060
1112
}
1061
1113
}
1062
1114
1063
1115
pub async fn finish ( mut self , writer : & mut MasWriter ) -> Result < ( ) , Error > {
1064
- self . finished = true ;
1065
1116
self . flush ( writer) . await ?;
1117
+ self . finish_checker_handle . declare_finished ( ) ;
1066
1118
Ok ( ( ) )
1067
1119
}
1068
1120
@@ -1085,12 +1137,6 @@ impl<T> MasWriteBuffer<T> {
1085
1137
}
1086
1138
}
1087
1139
1088
- impl < T > Drop for MasWriteBuffer < T > {
1089
- fn drop ( & mut self ) {
1090
- assert ! ( self . finished, "MasWriteBuffer dropped but not finished!" ) ;
1091
- }
1092
- }
1093
-
1094
1140
#[ cfg( test) ]
1095
1141
mod test {
1096
1142
use std:: collections:: { BTreeMap , BTreeSet } ;
0 commit comments