@@ -15,12 +15,21 @@ use std::collections::HashMap;
15
15
use std:: sync:: Arc ;
16
16
use std:: time:: Duration ;
17
17
use tokio:: time:: delay_for;
18
+ use crate :: db:: { get_channel_by_id} ;
18
19
19
- pub ( crate ) type Aggregate = Arc < RwLock < HashMap < ChannelId , EventAggregate > > > ;
20
+
21
+ #[ derive( Debug ) ]
22
+ struct Record {
23
+ channel : Channel ,
24
+ aggregate : EventAggregate
25
+ }
26
+
27
+ type Recorder = Arc < RwLock < HashMap < ChannelId , Record > > > ;
20
28
21
29
#[ derive( Default , Clone ) ]
22
30
pub struct EventAggregator {
23
- aggregate : Aggregate ,
31
+ // aggregate: Aggregate,
32
+ recorder : Recorder
24
33
}
25
34
26
35
pub fn new_aggr ( channel_id : & ChannelId ) -> EventAggregate {
@@ -31,15 +40,19 @@ pub fn new_aggr(channel_id: &ChannelId) -> EventAggregate {
31
40
}
32
41
}
33
42
34
- async fn store ( db : & DbPool , channel_id : & ChannelId , logger : & Logger , aggr : Aggregate ) {
35
- let mut recorder = aggr . write ( ) . await ;
36
- let ev_aggr : Option < & EventAggregate > = recorder . get ( channel_id) ;
37
- if let Some ( data) = ev_aggr {
38
- if let Err ( e) = insert_event_aggregate ( & db, & channel_id, data) . await {
39
- error ! ( & logger, "{}" , e; "eventaggregator " => "store" ) ;
43
+ async fn store ( db : & DbPool , channel_id : & ChannelId , logger : & Logger , recorder : Recorder ) {
44
+ let mut channel_recorder = recorder . write ( ) . await ;
45
+ let record : Option < & Record > = channel_recorder . get ( channel_id) ;
46
+ if let Some ( data) = record {
47
+ if let Err ( e) = insert_event_aggregate ( & db, & channel_id, & data. aggregate ) . await {
48
+ error ! ( & logger, "{}" , e; "event_aggregator " => "store" ) ;
40
49
} else {
41
- // reset aggr
42
- recorder. insert ( channel_id. to_owned ( ) , new_aggr ( & channel_id) ) ;
50
+ // reset aggr record
51
+ let record = Record {
52
+ channel : data. channel . to_owned ( ) ,
53
+ aggregate : new_aggr ( & channel_id)
54
+ } ;
55
+ channel_recorder. insert ( channel_id. to_owned ( ) , record) ;
43
56
} ;
44
57
}
45
58
}
@@ -48,38 +61,38 @@ impl EventAggregator {
48
61
pub async fn record < ' a , A : Adapter > (
49
62
& self ,
50
63
app : & ' a Application < A > ,
51
- channel : & Channel ,
64
+ channel_id : & ChannelId ,
52
65
session : & Session ,
53
66
events : & ' a [ Event ] ,
54
67
) -> Result < ( ) , ResponseError > {
55
- let has_access = check_access (
56
- & app. redis ,
57
- & session,
58
- & app. config . ip_rate_limit ,
59
- & channel,
60
- events,
61
- )
62
- . await ;
63
- if let Err ( e) = has_access {
64
- return Err ( ResponseError :: BadRequest ( e. to_string ( ) ) ) ;
65
- }
66
-
67
- let mut recorder = self . aggregate . write ( ) . await ;
68
+ let recorder = self . recorder . clone ( ) ;
68
69
let aggr_throttle = app. config . aggr_throttle ;
69
70
let dbpool = app. pool . clone ( ) ;
70
- let aggregate = self . aggregate . clone ( ) ;
71
- let withdraw_period_start = channel. spec . withdraw_period_start ;
72
- let channel_id = channel. id ;
73
71
let logger = app. logger . clone ( ) ;
74
72
75
- let mut aggr: & mut EventAggregate = match recorder. get_mut ( & channel. id ) {
76
- Some ( aggr) => aggr,
73
+ let mut channel_recorder = self . recorder . write ( ) . await ;
74
+ let record: & mut Record = match channel_recorder. get_mut ( & channel_id) {
75
+ Some ( record) => record,
77
76
None => {
77
+ // fetch channel
78
+ let channel = get_channel_by_id ( & app. pool , & channel_id)
79
+ . await ?
80
+ . ok_or_else ( || ResponseError :: NotFound ) ?;
81
+
82
+ let withdraw_period_start = channel. spec . withdraw_period_start ;
83
+ let channel_id = channel. id ;
84
+ let record = Record {
85
+ channel,
86
+ aggregate : new_aggr ( & channel_id)
87
+ } ;
88
+
78
89
// insert into
79
- recorder . insert ( channel . id , new_aggr ( & channel . id ) ) ;
90
+ channel_recorder . insert ( channel_id . to_owned ( ) , record ) ;
80
91
92
+ //
81
93
// spawn async task that persists
82
94
// the channel events to database
95
+ let recorder = recorder. clone ( ) ;
83
96
if aggr_throttle > 0 {
84
97
tokio:: spawn ( async move {
85
98
loop {
@@ -93,31 +106,43 @@ impl EventAggregator {
93
106
}
94
107
95
108
delay_for ( Duration :: from_secs ( aggr_throttle as u64 ) ) . await ;
96
- store ( & dbpool, & channel_id, & logger, aggregate . clone ( ) ) . await ;
109
+ store ( & dbpool, & channel_id, & logger, recorder . clone ( ) ) . await ;
97
110
}
98
111
} ) ;
99
112
}
100
113
101
- recorder
102
- . get_mut ( & channel . id )
114
+ channel_recorder
115
+ . get_mut ( & channel_id )
103
116
. expect ( "should have aggr, we just inserted" )
104
117
}
105
118
} ;
106
119
120
+ let has_access = check_access (
121
+ & app. redis ,
122
+ & session,
123
+ & app. config . ip_rate_limit ,
124
+ & record. channel ,
125
+ events,
126
+ )
127
+ . await ;
128
+ if let Err ( e) = has_access {
129
+ return Err ( ResponseError :: BadRequest ( e. to_string ( ) ) ) ;
130
+ }
131
+
107
132
events
108
133
. iter ( )
109
- . for_each ( |ev| event_reducer:: reduce ( & channel, & mut aggr , ev) ) ;
134
+ . for_each ( |ev| event_reducer:: reduce ( & record . channel , & mut record . aggregate , ev) ) ;
110
135
111
136
// drop write access to RwLock
112
137
// this is required to prevent a deadlock in store
113
- drop ( recorder ) ;
138
+ drop ( channel_recorder ) ;
114
139
115
140
if aggr_throttle == 0 {
116
141
store (
117
142
& app. pool ,
118
- & channel . id ,
143
+ & channel_id ,
119
144
& app. logger . clone ( ) ,
120
- self . aggregate . clone ( ) ,
145
+ recorder . clone ( ) ,
121
146
)
122
147
. await ;
123
148
}
0 commit comments