Skip to content

Commit 9f40de2

Browse files
committed
Parse TTL and max-bytes parameters in a human-friendly way
1 parent 94e6759 commit 9f40de2

File tree

5 files changed

+160
-22
lines changed

5 files changed

+160
-22
lines changed

server/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ repository = "https://github.com/matrix-org/rust-http-rendezvous-server/"
99
rust-version = "1.61"
1010

1111
[dependencies]
12+
bytesize = "1.1.0"
13+
clap = { version = "4.0.14", features = ["derive"] }
14+
humantime = "2.1.0"
1215
hyper = { version = "0.14.20", features = ["server"] }
1316
tokio = { version = "1.21.2", features = ["macros", "rt-multi-thread"] }
1417
tracing = "0.1.37"
1518
tracing-subscriber = "0.3.16"
16-
clap = { version = "4.0.14", features = ["derive"] }
1719

1820
matrix-http-rendezvous = { path = "../", version = "0.1.3" }

server/src/main.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#![deny(clippy::all)]
1717
#![warn(clippy::pedantic)]
1818

19+
use bytesize::ByteSize;
1920
use clap::Parser;
2021
use std::{
2122
net::{IpAddr, Ipv4Addr, SocketAddr},
@@ -37,12 +38,12 @@ struct Options {
3738
prefix: Option<String>,
3839

3940
/// Time to live of entries, in seconds
40-
#[arg(short, long, default_value_t = 60)]
41-
ttl: u64,
41+
#[arg(short, long, default_value_t = Duration::from_secs(60).into())]
42+
ttl: humantime::Duration,
4243

4344
/// Maximum payload size, in bytes
44-
#[arg(short, long, default_value_t = 4096)]
45-
max_bytes: usize,
45+
#[arg(short, long, default_value = "4KiB")]
46+
max_bytes: ByteSize,
4647
}
4748

4849
#[tokio::main]
@@ -51,17 +52,22 @@ async fn main() {
5152

5253
let options = Options::parse();
5354
let prefix = options.prefix.unwrap_or_default();
54-
let ttl = Duration::from_secs(options.ttl);
55+
let ttl = options.ttl.into();
56+
let max_bytes = options
57+
.max_bytes
58+
.0
59+
.try_into()
60+
.expect("Max bytes size too large");
5561

5662
let addr = SocketAddr::from((options.address, options.port));
5763

58-
let service = matrix_http_rendezvous::router(&prefix, ttl, options.max_bytes);
64+
let service = matrix_http_rendezvous::router(&prefix, ttl, max_bytes);
5965

6066
tracing::info!("Listening on http://{addr}");
6167
tracing::info!(
62-
"TTL: {ttl}s – Maximum payload size: {max_bytes} bytes",
63-
ttl = ttl.as_secs(),
64-
max_bytes = options.max_bytes
68+
"TTL: {ttl} – Maximum payload size: {max_bytes}",
69+
ttl = humantime::format_duration(ttl),
70+
max_bytes = options.max_bytes.to_string_as(true)
6571
);
6672

6773
hyper::Server::bind(&addr)

src/lib.rs

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::{
2626

2727
use axum::{
2828
body::HttpBody,
29-
extract::{FromRef, Path, State},
29+
extract::{DefaultBodyLimit, FromRef, Path, State},
3030
http::{
3131
header::{CONTENT_TYPE, ETAG, IF_MATCH, IF_NONE_MATCH, LOCATION},
3232
StatusCode,
@@ -253,6 +253,7 @@ where
253253

254254
Router::new()
255255
.nest(prefix, router)
256+
.layer(DefaultBodyLimit::disable())
256257
.layer(RequestBodyLimitLayer::new(max_bytes))
257258
.layer(SetResponseHeaderLayer::if_not_present(
258259
HeaderName::from_static("x-max-bytes"),
@@ -270,14 +271,71 @@ where
270271

271272
#[cfg(test)]
272273
mod tests {
274+
use std::convert::Infallible;
275+
273276
use super::*;
274277

275278
use axum::http::{
276279
header::{CONTENT_LENGTH, CONTENT_TYPE},
277280
Request,
278281
};
282+
use bytes::Buf;
279283
use tower::util::ServiceExt;
280284

285+
/// A slow body, which sends the bytes in small chunks (1 byte per chunk by default)
286+
#[derive(Clone)]
287+
struct SlowBody {
288+
body: Bytes,
289+
chunk_size: usize,
290+
}
291+
292+
impl SlowBody {
293+
const fn from_static(bytes: &'static [u8]) -> Self {
294+
Self {
295+
body: Bytes::from_static(bytes),
296+
chunk_size: 1,
297+
}
298+
}
299+
300+
const fn from_bytes(body: Bytes) -> Self {
301+
Self {
302+
body,
303+
chunk_size: 1,
304+
}
305+
}
306+
307+
const fn with_chunk_size(mut self, chunk_size: usize) -> Self {
308+
self.chunk_size = chunk_size;
309+
self
310+
}
311+
}
312+
313+
impl HttpBody for SlowBody {
314+
type Data = Bytes;
315+
type Error = Infallible;
316+
317+
fn poll_data(
318+
self: std::pin::Pin<&mut Self>,
319+
_cx: &mut std::task::Context<'_>,
320+
) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> {
321+
if self.body.is_empty() {
322+
std::task::Poll::Ready(None)
323+
} else {
324+
let size = self.chunk_size.min(self.body.len());
325+
let ret = self.body.slice(0..size);
326+
self.get_mut().body.advance(size);
327+
std::task::Poll::Ready(Some(Ok(ret)))
328+
}
329+
}
330+
331+
fn poll_trailers(
332+
self: std::pin::Pin<&mut Self>,
333+
_cx: &mut std::task::Context<'_>,
334+
) -> std::task::Poll<Result<Option<headers::HeaderMap>, Self::Error>> {
335+
std::task::Poll::Ready(Ok(None))
336+
}
337+
}
338+
281339
async fn advance_time(duration: Duration) {
282340
tokio::task::yield_now().await;
283341
tokio::time::pause();
@@ -325,6 +383,63 @@ mod tests {
325383
assert_eq!(response.status(), StatusCode::NOT_FOUND);
326384
}
327385

386+
#[tokio::test]
387+
async fn test_post_max_bytes() {
388+
let ttl = Duration::from_secs(60);
389+
390+
let body = br#"{"hello": "world"}"#;
391+
392+
// It doesn't work with a way too small size
393+
let slow_body = SlowBody::from_static(body);
394+
let request = Request::post("/")
395+
.header(CONTENT_TYPE, "application/json")
396+
.body(slow_body)
397+
.unwrap();
398+
let response = router("/", ttl, 8).oneshot(request).await.unwrap();
399+
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
400+
401+
// It works with exactly the right size
402+
let slow_body = SlowBody::from_static(body);
403+
let request = Request::post("/")
404+
.header(CONTENT_TYPE, "application/json")
405+
.body(slow_body)
406+
.unwrap();
407+
let response = router("/", ttl, body.len()).oneshot(request).await.unwrap();
408+
assert_eq!(response.status(), StatusCode::CREATED);
409+
410+
// It doesn't work even if the size is one too short
411+
let slow_body = SlowBody::from_static(body);
412+
let request = Request::post("/")
413+
.header(CONTENT_TYPE, "application/json")
414+
.body(slow_body)
415+
.unwrap();
416+
let response = router("/", ttl, body.len() - 1)
417+
.oneshot(request)
418+
.await
419+
.unwrap();
420+
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
421+
422+
// Try with a big body (4MB), sent in small 128 bytes chunks
423+
let body = vec![42; 4 * 1024 * 1024].into_boxed_slice();
424+
let slow_body = SlowBody::from_bytes(Bytes::from(body)).with_chunk_size(128);
425+
let request = Request::post("/").body(slow_body).unwrap();
426+
let response = router("/", ttl, 4 * 1024 * 1024)
427+
.oneshot(request)
428+
.await
429+
.unwrap();
430+
assert_eq!(response.status(), StatusCode::CREATED);
431+
432+
// Try with a big body (4MB + 1B), sent in small 128 bytes chunks
433+
let body = vec![42; 4 * 1024 * 1024 + 1].into_boxed_slice();
434+
let slow_body = SlowBody::from_bytes(Bytes::from(body)).with_chunk_size(128);
435+
let request = Request::post("/").body(slow_body).unwrap();
436+
let response = router("/", ttl, 4 * 1024 * 1024)
437+
.oneshot(request)
438+
.await
439+
.unwrap();
440+
assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
441+
}
442+
328443
#[tokio::test]
329444
async fn test_post_and_get_if_none_match() {
330445
let ttl = Duration::from_secs(60);

synapse/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ crate-type = ["cdylib"]
1515

1616
[dependencies]
1717
anyhow = "1.0.65"
18+
bytesize = { version = "1.1.0", features = ["serde"] }
1819
http-body = "0.4.5"
20+
humantime = "2.1.0"
21+
humantime-serde = "1.1.1"
1922
pyo3 = { version = "0.17.2", features = ["extension-module", "abi3-py37", "anyhow"] }
2023
pyo3-log = "0.7.0"
2124
pyo3-matrix-synapse-module = "0.1.0"

synapse/src/lib.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,33 @@
1919

2020
use std::time::Duration;
2121

22-
use anyhow::anyhow;
22+
use anyhow::{anyhow, Context};
23+
use bytesize::ByteSize;
2324
use http_body::Body;
2425
use pyo3::prelude::*;
2526
use serde::Deserialize;
2627
use tower::ServiceExt;
2728

2829
use pyo3_matrix_synapse_module::{parse_config, ModuleApi};
2930

30-
fn default_ttl() -> u64 {
31-
60
31+
fn default_ttl() -> Duration {
32+
Duration::from_secs(60)
3233
}
3334

34-
fn default_max_bytes() -> usize {
35-
4096
35+
fn default_max_bytes() -> ByteSize {
36+
ByteSize::kib(4)
3637
}
3738

3839
#[pyclass]
3940
#[derive(Deserialize)]
4041
struct Config {
4142
prefix: String,
4243

43-
#[serde(default = "default_ttl")]
44-
ttl: u64,
44+
#[serde(default = "default_ttl", with = "humantime_serde")]
45+
ttl: Duration,
4546

4647
#[serde(default = "default_max_bytes")]
47-
max_bytes: usize,
48+
max_bytes: ByteSize,
4849
}
4950

5051
#[pyclass]
@@ -54,10 +55,21 @@ pub struct SynapseRendezvousModule;
5455
impl SynapseRendezvousModule {
5556
#[new]
5657
fn new(config: &Config, module_api: ModuleApi) -> PyResult<Self> {
57-
let ttl = Duration::from_secs(config.ttl);
58-
let service = matrix_http_rendezvous::router(&config.prefix, ttl, config.max_bytes)
59-
.map_response(|res| res.map(|b| b.map_err(|e| anyhow!(e))));
58+
tracing::info!(
59+
"Mounting rendez-vous server on {prefix}, with a TTL of {ttl} and max payload size of {max_bytes}",
60+
prefix = config.prefix,
61+
ttl = humantime::format_duration(config.ttl),
62+
max_bytes = config.max_bytes.to_string_as(true),
63+
);
64+
65+
let max_bytes = config
66+
.max_bytes
67+
.0
68+
.try_into()
69+
.context("Could not convert max_bytes from config")?;
6070

71+
let service = matrix_http_rendezvous::router(&config.prefix, config.ttl, max_bytes)
72+
.map_response(|res| res.map(|b| b.map_err(|e| anyhow!(e))));
6173
module_api.register_web_service(&config.prefix, service)?;
6274
Ok(Self)
6375
}

0 commit comments

Comments
 (0)