Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,683 changes: 1,505 additions & 178 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions changelog.d/18357.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Increase performance of introspecting access tokens when using delegated auth.
13 changes: 9 additions & 4 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "synapse"
version = "0.1.0"

edition = "2021"
rust-version = "1.66.0"
rust-version = "1.81.0"

[lib]
name = "synapse"
Expand All @@ -30,19 +30,24 @@ http = "1.1.0"
lazy_static = "1.4.0"
log = "0.4.17"
mime = "0.3.17"
pyo3 = { version = "0.23.5", features = [
pyo3 = { version = "0.24.2", features = [
"macros",
"anyhow",
"abi3",
"abi3-py39",
] }
pyo3-log = "0.12.0"
pythonize = "0.23.0"
pyo3-log = "0.12.3"
pythonize = "0.24.0"
regex = "1.6.0"
sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
ulid = "1.1.2"
reqwest = { version = "0.12.15", features = ["stream"] }
pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] }
http-body-util = "0.1.3"
futures = "0.3.31"
tokio = { version = "1.38.2", features = ["rt", "rt-multi-thread"] }

[features]
extension-module = ["pyo3/extension-module"]
Expand Down
Empty file added rust/src/async_twisted.rs
Empty file.
12 changes: 12 additions & 0 deletions rust/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,15 @@ impl NotFoundError {
NotFoundError::new_err(())
}
}

import_exception!(synapse.api.errors, HttpResponseException);

impl HttpResponseException {
    /// Builds a Python `HttpResponseException` from an HTTP status and the raw
    /// response body.
    ///
    /// The exception is constructed with the tuple `(code, reason, body)`:
    /// the numeric status code, the status' canonical reason phrase (an empty
    /// string when the status has none), and the body bytes.
    pub fn new(status: StatusCode, bytes: Vec<u8>) -> pyo3::PyErr {
        let code = status.as_u16();
        let reason = status.canonical_reason().unwrap_or_default();

        HttpResponseException::new_err((code, reason, bytes))
    }
}
196 changes: 196 additions & 0 deletions rust/src/http_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2025 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/

use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe};

use anyhow::Context;
use futures::{FutureExt, TryStreamExt};
use lazy_static::lazy_static;
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
use reqwest::RequestBuilder;
use tokio::runtime::Runtime;

use crate::errors::HttpResponseException;

lazy_static! {
static ref RUNTIME: Runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll likely want to have that configurable at some point, but this is probably a sane default.

.enable_all()
.build()
.unwrap();
static ref DEFERRED_CLASS: PyObject = {
Python::with_gil(|py| {
py.import("twisted.internet.defer")
.expect("module 'twisted.internet.defer' should be importable")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of panicking like that, not sure what will happen if it's the case

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause a panic on the first and all subsequent derefs. Given that this shouldn't ever fail, I prefer having the initialisation closer to the definition for clarity's sake.

I've added an explicit call to these functions in the init of HttpClient so that if it ever did fail, it'd fail at startup.

.getattr("Deferred")
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
.unbind()
})
};
}

/// Called when registering modules with python.
///
/// Creates the `http_client` submodule, adds the [`HttpClient`] class to it
/// and attaches it to the parent module `m`.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let http_client_module = PyModule::new(py, "http_client")?;
    http_client_module.add_class::<HttpClient>()?;
    m.add_submodule(&http_client_module)?;

    // The submodule also has to be inserted into `sys.modules` by hand,
    // otherwise `from synapse.synapse_rust import http_client` does not work.
    let sys_modules = py.import("sys")?.getattr("modules")?;
    sys_modules.set_item("synapse.synapse_rust.http_client", http_client_module)?;

    Ok(())
}

// HTTP client class exposed to Python via `#[pyclass]`; its Python-callable
// methods live in the `#[pymethods]` impl block.
#[pyclass]
#[derive(Clone)]
struct HttpClient {
    // Underlying `reqwest` client used to build and send every request.
    client: reqwest::Client,
}

#[pymethods]
impl HttpClient {
    // Python constructor: wraps a freshly created `reqwest::Client`.
    #[new]
    pub fn py_new() -> HttpClient {
        let client = reqwest::Client::new();
        HttpClient { client }
    }

    // Starts a GET request to `url` and hands it to `send_request`, which
    // returns the Python object the caller waits on. `response_limit` is
    // forwarded unchanged as the cap on the response body size.
    pub fn get<'a>(
        &self,
        py: Python<'a>,
        url: String,
        response_limit: usize,
    ) -> PyResult<Bound<'a, PyAny>> {
        let request = self.client.get(url);
        self.send_request(py, request, response_limit)
    }

    // Starts a POST request to `url` carrying every entry of `headers` and
    // `request_body` as the body, then hands it to `send_request` exactly
    // like `get` does.
    pub fn post<'a>(
        &self,
        py: Python<'a>,
        url: String,
        response_limit: usize,
        headers: HashMap<String, String>,
        request_body: String,
    ) -> PyResult<Bound<'a, PyAny>> {
        // Thread the builder through each header, then attach the body.
        let request = headers
            .into_iter()
            .fold(self.client.post(url), |request, (name, value)| {
                request.header(name, value)
            })
            .body(request_body);

        self.send_request(py, request, response_limit)
    }
}

impl HttpClient {
fn send_request<'a>(
&self,
py: Python<'a>,
builder: RequestBuilder,
response_limit: usize,
) -> PyResult<Bound<'a, PyAny>> {
create_deferred(py, async move {
let response = builder.send().await.context("sending request")?;

let status = response.status();

let mut stream = response.bytes_stream();
let mut buffer = Vec::new();
while let Some(chunk) = stream.try_next().await.context("reading body")? {
if buffer.len() + chunk.len() > response_limit {
Err(anyhow::anyhow!("Response size too large"))?;
}

buffer.extend_from_slice(&chunk);
}
Comment on lines +137 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can achieve the same with http_body_util::Limited; reqwest::Response implements Into<Body>

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I originally tried that, but it messed up the errors (one of the exceptions stops implementing std Error). Given how straightforward this is, it felt easier than faffing with error types.


if !status.is_success() {
return Err(HttpResponseException::new(status, buffer));
}

let r = Python::with_gil(|py| buffer.into_pyobject(py).map(|o| o.unbind()))?;

Ok(r)
})
}
}

/// Creates a twisted deferred from the given future, spawning the task on the
/// tokio runtime.
///
/// Does not handle deferred cancellation or contextvars.
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
where
F: Future<Output = PyResult<O>> + Send + 'static,
for<'a> O: IntoPyObject<'a>,
{
let deferred = DEFERRED_CLASS.bind(py).call0()?;
let deferred_callback = deferred.getattr("callback")?.unbind();
let deferred_errback = deferred.getattr("errback")?.unbind();

let reactor = py.import("twisted.internet")?.getattr("reactor")?.unbind();

RUNTIME.spawn(async move {
// TODO: Is it safe to assert unwind safety here? I think so, as we
// don't use anything that could be tainted by the panic afterwards.
// Note that `.spawn(..)` asserts unwind safety on the future too.
let res = AssertUnwindSafe(fut).catch_unwind().await;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alternatively, spawn the future on the runtime, await on the handle, and it will give you an Err with the panic in case it panics

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, though then you end up spawning two tasks per function, rather than one. Probably not a huge deal, but feels a bit bleurgh

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spawning is cheap, let's do that instead please. Also we'll need to spawn a separate task anyway if we want to properly support cancel

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand on why you want to use new tasks please? I don't see the benefit of spawning a new task to just wait on it, semantically you end up with a bunch of tasks with different IDs all for the same work. In future, if we wanted to start tracking tasks and e.g. their resource usage then using multiple tasks makes that more complicated.

I also don't think we need a separate task for cancellation necessarily. You can change this line to do a select on both fut and the cancellation future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really comfortable with AssertUnwindSafe being used so broadly. Tasks are cheap to spawn, and I don't think we'd want to base our potential resource consumption measurement in Rust-world on Tokio task IDs?

Anyway, even though AssertUnwindSafe smells like a bad thing waiting to happen, I won't block this PR further because of this if you're not convinced that spawning is fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really comfortable with AssertUnwindSafe being used so broadly. Tasks are cheap to spawn, and I don't think we'd want to base our potential resource consumption measurement in Rust-world on Tokio task IDs?

I think we would do this at the task level, we'd have a task-local context that records resource usage, so you could e.g. wrap the top-level future that records the resource consumption of poll, or have DB functions record transaction times/etc. When spawning new tasks you'd want to decide if the resources of the task get allocated to the current task or to a new one.

Anyway, even though AssertUnwindSafe smells like a bad thing waiting to happen, I won't block this PR further because of this if you're not convinced that spawning is fine

Bear in mind that this is exactly what spawning a task does in tokio, so it's hard to see how it would be fine for that and not here.


Python::with_gil(move |py| {
// Flatten the panic into standard python error
let res = match res {
Ok(r) => r,
Err(panic_err) => {
let panic_message = get_panic_message(&panic_err);
Err(PyException::new_err(
PyString::new(py, panic_message).unbind(),
))
}
};

// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
match res {
Ok(obj) => {
reactor
.call_method(py, "callFromThread", (deferred_callback, obj), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
Err(err) => {
reactor
.call_method(py, "callFromThread", (deferred_errback, err), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
}
});
});

Ok(deferred)
}

/// Try and get the panic message out of the given panic payload.
///
/// Payloads created by `panic!` are normally either a `&'static str` (a plain
/// literal message) or a `String` (a formatted message). Any other payload
/// type yields the fallback `"unknown error"`.
///
/// A payload that is itself a `Box<dyn Any + Send>` is unwrapped first. This
/// matters for callers that hold the boxed payload returned by
/// `catch_unwind` and pass `&payload`: a `&Box<dyn Any + Send>` is
/// unsize-coerced to `&dyn Any` with the *box* as the concrete type (rather
/// than being dereferenced), so without this step neither downcast below
/// would ever match and the real message would be lost.
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
    // Peel off a wrapping box (see above) and inspect the payload inside it.
    if let Some(inner) = panic_err.downcast_ref::<Box<dyn std::any::Any + Send>>() {
        return get_panic_message(&**inner);
    }

    if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
        // `panic!("literal")`
        str_slice
    } else if let Some(string) = panic_err.downcast_ref::<String>() {
        // `panic!("formatted {}", value)`
        string
    } else {
        "unknown error"
    }
}
2 changes: 2 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod acl;
pub mod errors;
pub mod events;
pub mod http;
pub mod http_client;
pub mod identifier;
pub mod matrix_const;
pub mod push;
Expand Down Expand Up @@ -50,6 +51,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
acl::register_module(py, m)?;
push::register_module(py, m)?;
events::register_module(py, m)?;
http_client::register_module(py, m)?;
rendezvous::register_module(py, m)?;

Ok(())
Expand Down
50 changes: 24 additions & 26 deletions synapse/api/auth/msc3861_delegated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from prometheus_client import Histogram

from twisted.web.client import readBody
from twisted.web.http_headers import Headers

from synapse.api.auth.base import BaseAuth
from synapse.api.errors import (
AuthError,
Expand All @@ -44,8 +41,14 @@
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import (
active_span,
force_tracing,
inject_request_headers,
start_active_span,
)
from synapse.synapse_rust.http_client import HttpClient
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
Expand Down Expand Up @@ -164,6 +167,8 @@ class MSC3861DelegatedAuth(BaseAuth):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self._rust_http_client = HttpClient()

self._config = hs.config.experimental.msc3861
auth_method = MSC3861DelegatedAuth.AUTH_METHODS.get(
self._config.client_auth_method.value, None
Expand Down Expand Up @@ -316,38 +321,31 @@ async def _introspect_token(
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})

# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code, and we do the body encoding ourselves.

logger.debug("Fetching token from MAS")
start_time = self._clock.time()
try:
response = await self._http_client.request(
method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
)

resp_body = await make_deferred_yieldable(readBody(response))
with start_active_span("mas-introspect-token"):
inject_request_headers(raw_headers)
with PreserveLoggingContext():
resp_body = await self._rust_http_client.post(
uri, 1 * 1024 * 1024, raw_headers, body
)
except HttpResponseException as e:
end_time = self._clock.time()
introspection_response_timer.labels(e.code).observe(end_time - start_time)
raise
except Exception:
end_time = self._clock.time()
introspection_response_timer.labels("ERR").observe(end_time - start_time)
raise

end_time = self._clock.time()
introspection_response_timer.labels(response.code).observe(
end_time - start_time
)
logger.debug("Fetched token from MAS")

if response.code < 200 or response.code >= 300:
raise HttpResponseException(
response.code,
response.phrase.decode("ascii", errors="replace"),
resp_body,
)
end_time = self._clock.time()
introspection_response_timer.labels(200).observe(end_time - start_time)

resp = json_decoder.decode(resp_body.decode("utf-8"))

Expand Down
7 changes: 7 additions & 0 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,13 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")


@ensure_active_span("inject the span into a header dict")
def inject_request_headers(headers: Dict[str, str]) -> None:
    """Inject the active span's context into a dict of outgoing request headers.

    Args:
        headers: header name -> value mapping handed to the tracer's `inject`
            call as the carrier, using the HTTP_HEADERS format. Nothing is
            returned, so the tracing headers are presumably written into this
            dict in place (confirm against the tracer implementation).
    """
    tracer = opentracing.tracer
    current_span = tracer.active_span
    assert current_span is not None
    tracer.inject(current_span.context, opentracing.Format.HTTP_HEADERS, headers)


@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
Expand Down
24 changes: 24 additions & 0 deletions synapse/synapse_rust/http_client.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.

from typing import Mapping

# Type stub for the `HttpClient` class registered by the Rust
# `http_client` module.
class HttpClient:
    def __init__(self) -> None: ...
    # Fetch `url` with a GET request and return the response body.
    # `response_limit` is the maximum body size in bytes; the Rust side
    # fails the request with "Response size too large" once it is exceeded.
    async def get(self, url: str, response_limit: int) -> bytes: ...
    # Send `request_body` to `url` with a POST request, adding each entry of
    # `headers` to the request, and return the response body.
    # `response_limit` has the same meaning as for `get`.
    async def post(
        self,
        url: str,
        response_limit: int,
        headers: Mapping[str, str],
        request_body: str,
    ) -> bytes: ...
Loading