Skip to content

Commit 94e6759

Browse files
committed
Make the TTL and max payload size configurable
1 parent d73992d commit 94e6759

File tree

4 files changed

+87
-40
lines changed

4 files changed

+87
-40
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ sha2 = "0.10.6"
1919
time = "0.3.15"
2020
tokio = { version = "1.21.2", features = ["sync", "time"] }
2121
tower = { version = "0.4.13", features = ["util"] }
22-
tower-http = { version = "0.3.4", features = ["cors"] }
22+
tower-http = { version = "0.3.4", features = ["cors", "limit", "set-header"] }
2323
tracing = "0.1.37"
2424
uuid = { version = "1.1.2", features = ["v4", "fast-rng", "serde"] }
2525

server/src/main.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
#![warn(clippy::pedantic)]
1818

1919
use clap::Parser;
20-
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
20+
use std::{
21+
net::{IpAddr, Ipv4Addr, SocketAddr},
22+
time::Duration,
23+
};
2124

2225
#[derive(Parser)]
2326
struct Options {
@@ -32,6 +35,14 @@ struct Options {
3235
/// Path prefix on which to mount the rendez-vous server
3336
#[arg(long)]
3437
prefix: Option<String>,
38+
39+
/// Time to live of entries, in seconds
40+
#[arg(short, long, default_value_t = 60)]
41+
ttl: u64,
42+
43+
/// Maximum payload size, in bytes
44+
#[arg(short, long, default_value_t = 4096)]
45+
max_bytes: usize,
3546
}
3647

3748
#[tokio::main]
@@ -40,12 +51,18 @@ async fn main() {
4051

4152
let options = Options::parse();
4253
let prefix = options.prefix.unwrap_or_default();
54+
let ttl = Duration::from_secs(options.ttl);
4355

4456
let addr = SocketAddr::from((options.address, options.port));
4557

46-
let service = matrix_http_rendezvous::router(&prefix);
58+
let service = matrix_http_rendezvous::router(&prefix, ttl, options.max_bytes);
4759

4860
tracing::info!("Listening on http://{addr}");
61+
tracing::info!(
62+
"TTL: {ttl}s – Maximum payload size: {max_bytes} bytes",
63+
ttl = ttl.as_secs(),
64+
max_bytes = options.max_bytes
65+
);
4966

5067
hyper::Server::bind(&addr)
5168
.serve(service.into_make_service())

src/lib.rs

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::{
2626

2727
use axum::{
2828
body::HttpBody,
29-
extract::{ContentLengthLimit, FromRef, Path, State},
29+
extract::{FromRef, Path, State},
3030
http::{
3131
header::{CONTENT_TYPE, ETAG, IF_MATCH, IF_NONE_MATCH, LOCATION},
3232
StatusCode,
@@ -37,17 +37,19 @@ use axum::{
3737
};
3838
use base64ct::Encoding;
3939
use bytes::Bytes;
40-
use headers::{ContentType, ETag, Expires, HeaderName, IfMatch, IfNoneMatch, LastModified};
40+
use headers::{
41+
ContentType, ETag, Expires, HeaderName, HeaderValue, IfMatch, IfNoneMatch, LastModified,
42+
};
4143
use mime::Mime;
4244
use sha2::Digest;
4345
use tokio::sync::RwLock;
44-
use tower_http::cors::{Any, CorsLayer};
46+
use tower_http::{
47+
cors::{Any, CorsLayer},
48+
limit::RequestBodyLimitLayer,
49+
set_header::SetResponseHeaderLayer,
50+
};
4551
use uuid::Uuid;
4652

47-
// TODO: config?
48-
const MAX_BYTES: u64 = 4096;
49-
const TTL: Duration = Duration::from_secs(60);
50-
5153
struct Session {
5254
hash: [u8; 32],
5355
data: Bytes,
@@ -57,14 +59,14 @@ struct Session {
5759
}
5860

5961
impl Session {
60-
fn new(data: Bytes, content_type: Mime) -> Self {
62+
fn new(data: Bytes, content_type: Mime, ttl: Duration) -> Self {
6163
let hash = sha2::Sha256::digest(&data).into();
6264
let now = SystemTime::now();
6365
Self {
6466
hash,
6567
data,
6668
content_type,
67-
expires: now + TTL,
69+
expires: now + ttl,
6870
last_modified: now,
6971
}
7072
}
@@ -113,6 +115,7 @@ impl Session {
113115
struct Sessions {
114116
// TODO: is that global lock alright?
115117
inner: Arc<RwLock<HashMap<Uuid, Session>>>,
118+
ttl: Duration,
116119
}
117120

118121
impl Sessions {
@@ -153,25 +156,19 @@ impl AppState {
153156
async fn new_session(
154157
State(sessions): State<Sessions>,
155158
content_type: Option<TypedHeader<ContentType>>,
156-
// TODO: this requires a Content-Length header, is that alright?
157-
ContentLengthLimit(payload): ContentLengthLimit<Bytes, MAX_BYTES>,
159+
payload: Bytes,
158160
) -> impl IntoResponse {
161+
let ttl = sessions.ttl;
159162
// TODO: should we use something else? Check for colisions?
160163
let id = Uuid::new_v4();
161164
let content_type =
162165
content_type.map_or(mime::APPLICATION_OCTET_STREAM, |TypedHeader(c)| c.into());
163-
let session = Session::new(payload, content_type);
166+
let session = Session::new(payload, content_type, ttl);
164167
let headers = session.typed_headers();
165-
sessions.insert(id, session, TTL).await;
168+
sessions.insert(id, session, ttl).await;
166169

167170
let location = id.to_string();
168-
let additional_headers = [
169-
(LOCATION, location),
170-
(
171-
HeaderName::from_static("x-max-bytes"),
172-
MAX_BYTES.to_string(),
173-
),
174-
];
171+
let additional_headers = [(LOCATION, location)];
175172
(StatusCode::CREATED, headers, additional_headers)
176173
}
177174

@@ -188,7 +185,7 @@ async fn update_session(
188185
Path(id): Path<Uuid>,
189186
content_type: Option<TypedHeader<ContentType>>,
190187
if_match: Option<TypedHeader<IfMatch>>,
191-
ContentLengthLimit(payload): ContentLengthLimit<Bytes, MAX_BYTES>,
188+
payload: Bytes,
192189
) -> Response {
193190
if let Some(session) = sessions.write().await.get_mut(&id) {
194191
if let Some(TypedHeader(if_match)) = if_match {
@@ -235,13 +232,16 @@ async fn get_session(
235232
}
236233

237234
#[must_use]
238-
pub fn router<B>(prefix: &str) -> Router<(), B>
235+
pub fn router<B>(prefix: &str, ttl: Duration, max_bytes: usize) -> Router<(), B>
239236
where
240237
B: HttpBody + Send + 'static,
241238
<B as HttpBody>::Data: Send,
242239
<B as HttpBody>::Error: std::error::Error + Send + Sync,
243240
{
244-
let sessions = Sessions::default();
241+
let sessions = Sessions {
242+
inner: Arc::default(),
243+
ttl,
244+
};
245245

246246
let state = AppState::new(sessions);
247247
let router = Router::with_state(state)
@@ -251,13 +251,21 @@ where
251251
get(get_session).put(update_session).delete(delete_session),
252252
);
253253

254-
Router::new().nest(prefix, router).layer(
255-
CorsLayer::new()
256-
.allow_origin(Any)
257-
.allow_methods(Any)
258-
.allow_headers([CONTENT_TYPE, IF_MATCH, IF_NONE_MATCH])
259-
.expose_headers([ETAG, LOCATION, HeaderName::from_static("x-max-bytes")]),
260-
)
254+
Router::new()
255+
.nest(prefix, router)
256+
.layer(RequestBodyLimitLayer::new(max_bytes))
257+
.layer(SetResponseHeaderLayer::if_not_present(
258+
HeaderName::from_static("x-max-bytes"),
259+
HeaderValue::from_str(&max_bytes.to_string())
260+
.expect("Could not construct x-max-bytes header value"),
261+
))
262+
.layer(
263+
CorsLayer::new()
264+
.allow_origin(Any)
265+
.allow_methods(Any)
266+
.allow_headers([CONTENT_TYPE, IF_MATCH, IF_NONE_MATCH])
267+
.expose_headers([ETAG, LOCATION, HeaderName::from_static("x-max-bytes")]),
268+
)
261269
}
262270

263271
#[cfg(test)]
@@ -280,7 +288,8 @@ mod tests {
280288

281289
#[tokio::test]
282290
async fn test_post_and_get() {
283-
let app = router("/");
291+
let ttl = Duration::from_secs(60);
292+
let app = router("/", ttl, 4096);
284293

285294
let body = r#"{"hello": "world"}"#.to_string();
286295
let request = Request::post("/")
@@ -306,7 +315,7 @@ mod tests {
306315
assert_eq!(response.headers().get(ETAG).unwrap(), etag);
307316

308317
// Let the entry expire
309-
advance_time(TTL + Duration::from_secs(1)).await;
318+
advance_time(ttl + Duration::from_secs(1)).await;
310319

311320
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
312321
assert_eq!(&body[..], br#"{"hello": "world"}"#);
@@ -318,7 +327,8 @@ mod tests {
318327

319328
#[tokio::test]
320329
async fn test_post_and_get_if_none_match() {
321-
let app = router("/");
330+
let ttl = Duration::from_secs(60);
331+
let app = router("/", ttl, 4096);
322332

323333
let body = r#"{"hello": "world"}"#.to_string();
324334
let request = Request::post("/")
@@ -344,7 +354,8 @@ mod tests {
344354

345355
#[tokio::test]
346356
async fn test_post_and_put() {
347-
let app = router("/");
357+
let ttl = Duration::from_secs(60);
358+
let app = router("/", ttl, 4096);
348359

349360
let body = r#"{"hello": "world"}"#.to_string();
350361
let request = Request::post("/")
@@ -370,7 +381,8 @@ mod tests {
370381

371382
#[tokio::test]
372383
async fn test_post_and_put_if_match() {
373-
let app = router("/");
384+
let ttl = Duration::from_secs(60);
385+
let app = router("/", ttl, 4096);
374386

375387
let body = r#"{"hello": "world"}"#.to_string();
376388
let request = Request::post("/")
@@ -407,7 +419,8 @@ mod tests {
407419

408420
#[tokio::test]
409421
async fn test_post_delete_and_get() {
410-
let app = router("/");
422+
let ttl = Duration::from_secs(60);
423+
let app = router("/", ttl, 4096);
411424

412425
let body = r#"{"hello": "world"}"#.to_string();
413426
let request = Request::post("/")

synapse/src/lib.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#![warn(clippy::pedantic)]
1818
#![allow(clippy::needless_pass_by_value)]
1919

20+
use std::time::Duration;
21+
2022
use anyhow::anyhow;
2123
use http_body::Body;
2224
use pyo3::prelude::*;
@@ -25,10 +27,24 @@ use tower::ServiceExt;
2527

2628
use pyo3_matrix_synapse_module::{parse_config, ModuleApi};
2729

30+
fn default_ttl() -> u64 {
31+
60
32+
}
33+
34+
fn default_max_bytes() -> usize {
35+
4096
36+
}
37+
2838
#[pyclass]
2939
#[derive(Deserialize)]
3040
struct Config {
3141
prefix: String,
42+
43+
#[serde(default = "default_ttl")]
44+
ttl: u64,
45+
46+
#[serde(default = "default_max_bytes")]
47+
max_bytes: usize,
3248
}
3349

3450
#[pyclass]
@@ -38,7 +54,8 @@ pub struct SynapseRendezvousModule;
3854
impl SynapseRendezvousModule {
3955
#[new]
4056
fn new(config: &Config, module_api: ModuleApi) -> PyResult<Self> {
41-
let service = matrix_http_rendezvous::router(&config.prefix)
57+
let ttl = Duration::from_secs(config.ttl);
58+
let service = matrix_http_rendezvous::router(&config.prefix, ttl, config.max_bytes)
4259
.map_response(|res| res.map(|b| b.map_err(|e| anyhow!(e))));
4360

4461
module_api.register_web_service(&config.prefix, service)?;

0 commit comments

Comments
 (0)