Skip to content

Commit db415ec

Browse files
committed
fix: remove channel load middleware /:id/events
1 parent b94045a commit db415ec

File tree

3 files changed

+68
-44
lines changed

3 files changed

+68
-44
lines changed

sentry/src/event_aggregator.rs

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@ use std::collections::HashMap;
1515
use std::sync::Arc;
1616
use std::time::Duration;
1717
use tokio::time::delay_for;
18+
use crate::db::{get_channel_by_id};
1819

19-
pub(crate) type Aggregate = Arc<RwLock<HashMap<ChannelId, EventAggregate>>>;
20+
21+
#[derive(Debug)]
22+
struct Record {
23+
channel: Channel,
24+
aggregate: EventAggregate
25+
}
26+
27+
type Recorder = Arc<RwLock<HashMap<ChannelId, Record>>>;
2028

2129
#[derive(Default, Clone)]
2230
pub struct EventAggregator {
23-
aggregate: Aggregate,
31+
// aggregate: Aggregate,
32+
recorder: Recorder
2433
}
2534

2635
pub fn new_aggr(channel_id: &ChannelId) -> EventAggregate {
@@ -31,15 +40,19 @@ pub fn new_aggr(channel_id: &ChannelId) -> EventAggregate {
3140
}
3241
}
3342

34-
async fn store(db: &DbPool, channel_id: &ChannelId, logger: &Logger, aggr: Aggregate) {
35-
let mut recorder = aggr.write().await;
36-
let ev_aggr: Option<&EventAggregate> = recorder.get(channel_id);
37-
if let Some(data) = ev_aggr {
38-
if let Err(e) = insert_event_aggregate(&db, &channel_id, data).await {
39-
error!(&logger, "{}", e; "eventaggregator" => "store");
43+
async fn store(db: &DbPool, channel_id: &ChannelId, logger: &Logger, recorder: Recorder) {
44+
let mut channel_recorder = recorder.write().await;
45+
let record: Option<&Record> = channel_recorder.get(channel_id);
46+
if let Some(data) = record {
47+
if let Err(e) = insert_event_aggregate(&db, &channel_id, &data.aggregate).await {
48+
error!(&logger, "{}", e; "event_aggregator" => "store");
4049
} else {
41-
// reset aggr
42-
recorder.insert(channel_id.to_owned(), new_aggr(&channel_id));
50+
// reset aggr record
51+
let record = Record {
52+
channel: data.channel.to_owned(),
53+
aggregate: new_aggr(&channel_id)
54+
};
55+
channel_recorder.insert(channel_id.to_owned(), record);
4356
};
4457
}
4558
}
@@ -48,38 +61,38 @@ impl EventAggregator {
4861
pub async fn record<'a, A: Adapter>(
4962
&self,
5063
app: &'a Application<A>,
51-
channel: &Channel,
64+
channel_id: &ChannelId,
5265
session: &Session,
5366
events: &'a [Event],
5467
) -> Result<(), ResponseError> {
55-
let has_access = check_access(
56-
&app.redis,
57-
&session,
58-
&app.config.ip_rate_limit,
59-
&channel,
60-
events,
61-
)
62-
.await;
63-
if let Err(e) = has_access {
64-
return Err(ResponseError::BadRequest(e.to_string()));
65-
}
66-
67-
let mut recorder = self.aggregate.write().await;
68+
let recorder = self.recorder.clone();
6869
let aggr_throttle = app.config.aggr_throttle;
6970
let dbpool = app.pool.clone();
70-
let aggregate = self.aggregate.clone();
71-
let withdraw_period_start = channel.spec.withdraw_period_start;
72-
let channel_id = channel.id;
7371
let logger = app.logger.clone();
7472

75-
let mut aggr: &mut EventAggregate = match recorder.get_mut(&channel.id) {
76-
Some(aggr) => aggr,
73+
let mut channel_recorder = self.recorder.write().await;
74+
let record: &mut Record = match channel_recorder.get_mut(&channel_id) {
75+
Some(record) => record,
7776
None => {
77+
// fetch channel
78+
let channel = get_channel_by_id(&app.pool, &channel_id)
79+
.await?
80+
.ok_or_else(|| ResponseError::NotFound)?;
81+
82+
let withdraw_period_start = channel.spec.withdraw_period_start;
83+
let channel_id = channel.id;
84+
let record = Record {
85+
channel,
86+
aggregate: new_aggr(&channel_id)
87+
};
88+
7889
// insert into
79-
recorder.insert(channel.id, new_aggr(&channel.id));
90+
channel_recorder.insert(channel_id.to_owned(), record);
8091

92+
//
8193
// spawn async task that persists
8294
// the channel events to database
95+
let recorder = recorder.clone();
8396
if aggr_throttle > 0 {
8497
tokio::spawn(async move {
8598
loop {
@@ -93,31 +106,43 @@ impl EventAggregator {
93106
}
94107

95108
delay_for(Duration::from_secs(aggr_throttle as u64)).await;
96-
store(&dbpool, &channel_id, &logger, aggregate.clone()).await;
109+
store(&dbpool, &channel_id, &logger, recorder.clone()).await;
97110
}
98111
});
99112
}
100113

101-
recorder
102-
.get_mut(&channel.id)
114+
channel_recorder
115+
.get_mut(&channel_id)
103116
.expect("should have aggr, we just inserted")
104117
}
105118
};
106119

120+
let has_access = check_access(
121+
&app.redis,
122+
&session,
123+
&app.config.ip_rate_limit,
124+
&record.channel,
125+
events,
126+
)
127+
.await;
128+
if let Err(e) = has_access {
129+
return Err(ResponseError::BadRequest(e.to_string()));
130+
}
131+
107132
events
108133
.iter()
109-
.for_each(|ev| event_reducer::reduce(&channel, &mut aggr, ev));
134+
.for_each(|ev| event_reducer::reduce(&record.channel, &mut record.aggregate, ev));
110135

111136
// drop write access to RwLock
112137
// this is required to prevent a deadlock in store
113-
drop(recorder);
138+
drop(channel_recorder);
114139

115140
if aggr_throttle == 0 {
116141
store(
117142
&app.pool,
118-
&channel.id,
143+
&channel_id,
119144
&app.logger.clone(),
120-
self.aggregate.clone(),
145+
recorder.clone(),
121146
)
122147
.await;
123148
}

sentry/src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,6 @@ async fn channels_router<A: Adapter + 'static>(
264264
]);
265265
req.extensions_mut().insert(param);
266266

267-
let req = chain(req, app, vec![Box::new(channel_load)]).await?;
268-
269267
list_channel_event_aggregates(req, app).await
270268
} else if let (Some(caps), &Method::POST) =
271269
(CREATE_EVENTS_BY_CHANNEL_ID.captures(&path), method)

sentry/src/routes/channel.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,12 @@ pub async fn insert_events<A: Adapter + 'static>(
114114
.expect("request session")
115115
.to_owned();
116116

117-
let channel = req
117+
let route_params = req
118118
.extensions()
119-
.get::<Channel>()
120-
.expect("Request should have Channel")
121-
.to_owned();
119+
.get::<RouteParams>()
120+
.expect("request should have route params");
121+
122+
let channel_id = ChannelId::from_hex(route_params.index(0))?;
122123

123124
let into_body = req.into_body();
124125
let body = hyper::body::to_bytes(into_body).await?;
@@ -128,7 +129,7 @@ pub async fn insert_events<A: Adapter + 'static>(
128129
.ok_or_else(|| ResponseError::BadRequest("invalid request".to_string()))?;
129130

130131
app.event_aggregator
131-
.record(app, &channel, &session, &events)
132+
.record(app, &channel_id, &session, &events)
132133
.await?;
133134

134135
Ok(Response::builder()

0 commit comments

Comments
 (0)