Skip to content

Commit f031105

Browse files
authored
Start and stop the Tokio runtime with the Twisted reactor (#18691)
Fixes #18659. This changes the Tokio runtime to be attached to the Twisted reactor. This way, the Tokio runtime starts when the Twisted reactor starts, and *not* when the module gets loaded. This is important because starting the runtime on module load broke it when Synapse was started with `daemonize`/`synctl`: a fork only retains the calling thread, which breaks the Tokio runtime. This also changes the HttpClient so that it gets the Twisted reactor explicitly as a parameter instead of loading it from `twisted.internet.reactor`.
1 parent a0d6469 commit f031105

File tree

6 files changed

+166
-75
lines changed

6 files changed

+166
-75
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

changelog.d/18691.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix the MAS integration not working when Synapse is started with `--daemonize` or using `synctl`.

rust/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ reqwest = { version = "0.12.15", default-features = false, features = [
5252
http-body-util = "0.1.3"
5353
futures = "0.3.31"
5454
tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
55+
once_cell = "1.18.0"
5556

5657
[features]
5758
extension-module = ["pyo3/extension-module"]

rust/src/http_client.rs

Lines changed: 158 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -12,58 +12,149 @@
1212
* <https://www.gnu.org/licenses/agpl-3.0.html>.
1313
*/
1414

15-
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
15+
use std::{collections::HashMap, future::Future};
1616

1717
use anyhow::Context;
18-
use futures::{FutureExt, TryStreamExt};
19-
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
18+
use futures::TryStreamExt;
19+
use once_cell::sync::OnceCell;
20+
use pyo3::{create_exception, exceptions::PyException, prelude::*};
2021
use reqwest::RequestBuilder;
2122
use tokio::runtime::Runtime;
2223

2324
use crate::errors::HttpResponseException;
2425

25-
/// The tokio runtime that we're using to run async Rust libs.
26-
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
27-
tokio::runtime::Builder::new_multi_thread()
28-
.worker_threads(4)
29-
.enable_all()
30-
.build()
31-
.unwrap()
32-
});
33-
34-
/// A reference to the `Deferred` python class.
35-
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
36-
Python::with_gil(|py| {
37-
py.import("twisted.internet.defer")
38-
.expect("module 'twisted.internet.defer' should be importable")
39-
.getattr("Deferred")
40-
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
41-
.unbind()
42-
})
43-
});
44-
45-
/// A reference to the twisted `reactor`.
46-
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
47-
Python::with_gil(|py| {
48-
py.import("twisted.internet.reactor")
49-
.expect("module 'twisted.internet.reactor' should be importable")
50-
.unbind()
51-
})
52-
});
26+
create_exception!(
27+
synapse.synapse_rust.http_client,
28+
RustPanicError,
29+
PyException,
30+
"A panic which happened in a Rust future"
31+
);
32+
33+
impl RustPanicError {
34+
fn from_panic(panic_err: &(dyn std::any::Any + Send + 'static)) -> PyErr {
35+
// Apparently this is how you extract the panic message from a panic
36+
let panic_message = if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
37+
str_slice
38+
} else if let Some(string) = panic_err.downcast_ref::<String>() {
39+
string
40+
} else {
41+
"unknown error"
42+
};
43+
Self::new_err(panic_message.to_owned())
44+
}
45+
}
46+
47+
/// This is the name of the attribute where we store the runtime on the reactor
48+
static TOKIO_RUNTIME_ATTR: &str = "__synapse_rust_tokio_runtime";
49+
50+
/// A Python wrapper around a Tokio runtime.
51+
///
52+
/// This allows us to 'store' the runtime on the reactor instance, starting it
53+
/// when the reactor starts, and stopping it when the reactor shuts down.
54+
#[pyclass]
55+
struct PyTokioRuntime {
56+
runtime: Option<Runtime>,
57+
}
58+
59+
#[pymethods]
60+
impl PyTokioRuntime {
61+
fn start(&mut self) -> PyResult<()> {
62+
// TODO: allow customization of the runtime like the number of threads
63+
let runtime = tokio::runtime::Builder::new_multi_thread()
64+
.worker_threads(4)
65+
.enable_all()
66+
.build()?;
67+
68+
self.runtime = Some(runtime);
69+
70+
Ok(())
71+
}
72+
73+
fn shutdown(&mut self) -> PyResult<()> {
74+
let runtime = self
75+
.runtime
76+
.take()
77+
.context("Runtime was already shutdown")?;
78+
79+
// Dropping the runtime will shut it down
80+
drop(runtime);
81+
82+
Ok(())
83+
}
84+
}
85+
86+
impl PyTokioRuntime {
87+
/// Get the handle to the Tokio runtime, if it is running.
88+
fn handle(&self) -> PyResult<&tokio::runtime::Handle> {
89+
let handle = self
90+
.runtime
91+
.as_ref()
92+
.context("Tokio runtime is not running")?
93+
.handle();
94+
95+
Ok(handle)
96+
}
97+
}
98+
99+
/// Get a handle to the Tokio runtime stored on the reactor instance, or create
100+
/// a new one.
101+
fn runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult<PyRef<'a, PyTokioRuntime>> {
102+
if !reactor.hasattr(TOKIO_RUNTIME_ATTR)? {
103+
install_runtime(reactor)?;
104+
}
105+
106+
get_runtime(reactor)
107+
}
108+
109+
/// Install a new Tokio runtime on the reactor instance.
110+
fn install_runtime(reactor: &Bound<PyAny>) -> PyResult<()> {
111+
let py = reactor.py();
112+
let runtime = PyTokioRuntime { runtime: None };
113+
let runtime = runtime.into_pyobject(py)?;
114+
115+
// Attach the runtime to the reactor, starting it when the reactor is
116+
// running, stopping it when the reactor is shutting down
117+
reactor.call_method1("callWhenRunning", (runtime.getattr("start")?,))?;
118+
reactor.call_method1(
119+
"addSystemEventTrigger",
120+
("after", "shutdown", runtime.getattr("shutdown")?),
121+
)?;
122+
reactor.setattr(TOKIO_RUNTIME_ATTR, runtime)?;
123+
124+
Ok(())
125+
}
126+
127+
/// Get a reference to a Tokio runtime handle stored on the reactor instance.
128+
fn get_runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult<PyRef<'a, PyTokioRuntime>> {
129+
// This will raise if `TOKIO_RUNTIME_ATTR` is not set or if it is
130+
// not a `PyTokioRuntime`. Careful that this could happen if the user sets it
131+
// manually, or if multiple versions of `pyo3-twisted` are used!
132+
let runtime: Bound<PyTokioRuntime> = reactor.getattr(TOKIO_RUNTIME_ATTR)?.extract()?;
133+
Ok(runtime.borrow())
134+
}
135+
136+
/// A reference to the `twisted.internet.defer` module.
137+
static DEFER: OnceCell<PyObject> = OnceCell::new();
138+
139+
/// Access to the `twisted.internet.defer` module.
140+
fn defer(py: Python<'_>) -> PyResult<&Bound<PyAny>> {
141+
Ok(DEFER
142+
.get_or_try_init(|| py.import("twisted.internet.defer").map(Into::into))?
143+
.bind(py))
144+
}
53145

54146
/// Called when registering modules with python.
55147
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
56148
let child_module: Bound<'_, PyModule> = PyModule::new(py, "http_client")?;
57149
child_module.add_class::<HttpClient>()?;
58150

59-
// Make sure we fail early if we can't build the lazy statics.
60-
LazyLock::force(&RUNTIME);
61-
LazyLock::force(&DEFERRED_CLASS);
151+
// Make sure we fail early if we can't load some modules
152+
defer(py)?;
62153

63154
m.add_submodule(&child_module)?;
64155

65156
// We need to manually add the module to sys.modules to make `from
66-
// synapse.synapse_rust import acl` work.
157+
// synapse.synapse_rust import http_client` work.
67158
py.import("sys")?
68159
.getattr("modules")?
69160
.set_item("synapse.synapse_rust.http_client", child_module)?;
@@ -72,26 +163,24 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
72163
}
73164

74165
#[pyclass]
75-
#[derive(Clone)]
76166
struct HttpClient {
77167
client: reqwest::Client,
168+
reactor: PyObject,
78169
}
79170

80171
#[pymethods]
81172
impl HttpClient {
82173
#[new]
83-
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
84-
// The twisted reactor can only be imported after Synapse has been
85-
// imported, to allow Synapse to change the twisted reactor. If we try
86-
// and import the reactor too early twisted installs a default reactor,
87-
// which can't be replaced.
88-
LazyLock::force(&TWISTED_REACTOR);
174+
pub fn py_new(reactor: Bound<PyAny>, user_agent: &str) -> PyResult<HttpClient> {
175+
// Make sure the runtime gets installed
176+
let _ = runtime(&reactor)?;
89177

90178
Ok(HttpClient {
91179
client: reqwest::Client::builder()
92180
.user_agent(user_agent)
93181
.build()
94182
.context("building reqwest client")?,
183+
reactor: reactor.unbind(),
95184
})
96185
}
97186

@@ -129,7 +218,7 @@ impl HttpClient {
129218
builder: RequestBuilder,
130219
response_limit: usize,
131220
) -> PyResult<Bound<'a, PyAny>> {
132-
create_deferred(py, async move {
221+
create_deferred(py, self.reactor.bind(py), async move {
133222
let response = builder.send().await.context("sending request")?;
134223

135224
let status = response.status();
@@ -159,43 +248,51 @@ impl HttpClient {
159248
/// tokio runtime.
160249
///
161250
/// Does not handle deferred cancellation or contextvars.
162-
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
251+
fn create_deferred<'py, F, O>(
252+
py: Python<'py>,
253+
reactor: &Bound<'py, PyAny>,
254+
fut: F,
255+
) -> PyResult<Bound<'py, PyAny>>
163256
where
164257
F: Future<Output = PyResult<O>> + Send + 'static,
165-
for<'a> O: IntoPyObject<'a>,
258+
for<'a> O: IntoPyObject<'a> + Send + 'static,
166259
{
167-
let deferred = DEFERRED_CLASS.bind(py).call0()?;
260+
let deferred = defer(py)?.call_method0("Deferred")?;
168261
let deferred_callback = deferred.getattr("callback")?.unbind();
169262
let deferred_errback = deferred.getattr("errback")?.unbind();
170263

171-
RUNTIME.spawn(async move {
172-
// TODO: Is it safe to assert unwind safety here? I think so, as we
173-
// don't use anything that could be tainted by the panic afterwards.
174-
// Note that `.spawn(..)` asserts unwind safety on the future too.
175-
let res = AssertUnwindSafe(fut).catch_unwind().await;
264+
let rt = runtime(reactor)?;
265+
let handle = rt.handle()?;
266+
let task = handle.spawn(fut);
267+
268+
// Unbind the reactor so that we can pass it to the task
269+
let reactor = reactor.clone().unbind();
270+
handle.spawn(async move {
271+
let res = task.await;
176272

177273
Python::with_gil(move |py| {
178274
// Flatten the panic into standard python error
179275
let res = match res {
180276
Ok(r) => r,
181-
Err(panic_err) => {
182-
let panic_message = get_panic_message(&panic_err);
183-
Err(PyException::new_err(
184-
PyString::new(py, panic_message).unbind(),
185-
))
186-
}
277+
Err(join_err) => match join_err.try_into_panic() {
278+
Ok(panic_err) => Err(RustPanicError::from_panic(&panic_err)),
279+
Err(err) => Err(PyException::new_err(format!("Task cancelled: {err}"))),
280+
},
187281
};
188282

283+
// Re-bind the reactor
284+
let reactor = reactor.bind(py);
285+
189286
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
190287
match res {
191288
Ok(obj) => {
192-
TWISTED_REACTOR
193-
.call_method(py, "callFromThread", (deferred_callback, obj), None)
289+
reactor
290+
.call_method("callFromThread", (deferred_callback, obj), None)
194291
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
195292
}
196293
Err(err) => {
197-
TWISTED_REACTOR
198-
.call_method(py, "callFromThread", (deferred_errback, err), None)
294+
reactor
295+
.call_method("callFromThread", (deferred_errback, err), None)
199296
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
200297
}
201298
}
@@ -204,15 +301,3 @@ where
204301

205302
Ok(deferred)
206303
}
207-
208-
/// Try and get the panic message out of the panic
209-
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
210-
// Apparently this is how you extract the panic message from a panic
211-
if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
212-
str_slice
213-
} else if let Some(string) = panic_err.downcast_ref::<String>() {
214-
string
215-
} else {
216-
"unknown error"
217-
}
218-
}

synapse/api/auth/msc3861_delegated.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def __init__(self, hs: "HomeServer"):
184184
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
185185

186186
self._rust_http_client = HttpClient(
187-
user_agent=self._http_client.user_agent.decode("utf8")
187+
reactor=hs.get_reactor(),
188+
user_agent=self._http_client.user_agent.decode("utf8"),
188189
)
189190

190191
# # Token Introspection Cache

synapse/synapse_rust/http_client.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
from typing import Awaitable, Mapping
1414

15+
from synapse.types import ISynapseReactor
16+
1517
class HttpClient:
16-
def __init__(self, user_agent: str) -> None: ...
18+
def __init__(self, reactor: ISynapseReactor, user_agent: str) -> None: ...
1719
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
1820
def post(
1921
self,

0 commit comments

Comments
 (0)