12
12
* <https://www.gnu.org/licenses/agpl-3.0.html>.
13
13
*/
14
14
15
- use std:: { collections:: HashMap , future:: Future , panic :: AssertUnwindSafe , sync :: LazyLock } ;
15
+ use std:: { collections:: HashMap , future:: Future } ;
16
16
17
17
use anyhow:: Context ;
18
- use futures:: { FutureExt , TryStreamExt } ;
19
- use pyo3:: { exceptions:: PyException , prelude:: * , types:: PyString } ;
18
+ use futures:: TryStreamExt ;
19
+ use once_cell:: sync:: OnceCell ;
20
+ use pyo3:: { create_exception, exceptions:: PyException , prelude:: * } ;
20
21
use reqwest:: RequestBuilder ;
21
22
use tokio:: runtime:: Runtime ;
22
23
23
24
use crate :: errors:: HttpResponseException ;
24
25
25
- /// The tokio runtime that we're using to run async Rust libs.
26
- static RUNTIME : LazyLock < Runtime > = LazyLock :: new ( || {
27
- tokio:: runtime:: Builder :: new_multi_thread ( )
28
- . worker_threads ( 4 )
29
- . enable_all ( )
30
- . build ( )
31
- . unwrap ( )
32
- } ) ;
33
-
34
- /// A reference to the `Deferred` python class.
35
- static DEFERRED_CLASS : LazyLock < PyObject > = LazyLock :: new ( || {
36
- Python :: with_gil ( |py| {
37
- py. import ( "twisted.internet.defer" )
38
- . expect ( "module 'twisted.internet.defer' should be importable" )
39
- . getattr ( "Deferred" )
40
- . expect ( "module 'twisted.internet.defer' should have a 'Deferred' class" )
41
- . unbind ( )
42
- } )
43
- } ) ;
44
-
45
- /// A reference to the twisted `reactor`.
46
- static TWISTED_REACTOR : LazyLock < Py < PyModule > > = LazyLock :: new ( || {
47
- Python :: with_gil ( |py| {
48
- py. import ( "twisted.internet.reactor" )
49
- . expect ( "module 'twisted.internet.reactor' should be importable" )
50
- . unbind ( )
51
- } )
52
- } ) ;
26
+ create_exception ! (
27
+ synapse. synapse_rust. http_client,
28
+ RustPanicError ,
29
+ PyException ,
30
+ "A panic which happened in a Rust future"
31
+ ) ;
32
+
33
+ impl RustPanicError {
34
+ fn from_panic ( panic_err : & ( dyn std:: any:: Any + Send + ' static ) ) -> PyErr {
35
+ // Apparently this is how you extract the panic message from a panic
36
+ let panic_message = if let Some ( str_slice) = panic_err. downcast_ref :: < & str > ( ) {
37
+ str_slice
38
+ } else if let Some ( string) = panic_err. downcast_ref :: < String > ( ) {
39
+ string
40
+ } else {
41
+ "unknown error"
42
+ } ;
43
+ Self :: new_err ( panic_message. to_owned ( ) )
44
+ }
45
+ }
46
+
47
+ /// This is the name of the attribute where we store the runtime on the reactor
48
+ static TOKIO_RUNTIME_ATTR : & str = "__synapse_rust_tokio_runtime" ;
49
+
50
+ /// A Python wrapper around a Tokio runtime.
51
+ ///
52
+ /// This allows us to 'store' the runtime on the reactor instance, starting it
53
+ /// when the reactor starts, and stopping it when the reactor shuts down.
54
+ #[ pyclass]
55
+ struct PyTokioRuntime {
56
+ runtime : Option < Runtime > ,
57
+ }
58
+
59
+ #[ pymethods]
60
+ impl PyTokioRuntime {
61
+ fn start ( & mut self ) -> PyResult < ( ) > {
62
+ // TODO: allow customization of the runtime like the number of threads
63
+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
64
+ . worker_threads ( 4 )
65
+ . enable_all ( )
66
+ . build ( ) ?;
67
+
68
+ self . runtime = Some ( runtime) ;
69
+
70
+ Ok ( ( ) )
71
+ }
72
+
73
+ fn shutdown ( & mut self ) -> PyResult < ( ) > {
74
+ let runtime = self
75
+ . runtime
76
+ . take ( )
77
+ . context ( "Runtime was already shutdown" ) ?;
78
+
79
+ // Dropping the runtime will shut it down
80
+ drop ( runtime) ;
81
+
82
+ Ok ( ( ) )
83
+ }
84
+ }
85
+
86
+ impl PyTokioRuntime {
87
+ /// Get the handle to the Tokio runtime, if it is running.
88
+ fn handle ( & self ) -> PyResult < & tokio:: runtime:: Handle > {
89
+ let handle = self
90
+ . runtime
91
+ . as_ref ( )
92
+ . context ( "Tokio runtime is not running" ) ?
93
+ . handle ( ) ;
94
+
95
+ Ok ( handle)
96
+ }
97
+ }
98
+
99
+ /// Get a handle to the Tokio runtime stored on the reactor instance, or create
100
+ /// a new one.
101
+ fn runtime < ' a > ( reactor : & Bound < ' a , PyAny > ) -> PyResult < PyRef < ' a , PyTokioRuntime > > {
102
+ if !reactor. hasattr ( TOKIO_RUNTIME_ATTR ) ? {
103
+ install_runtime ( reactor) ?;
104
+ }
105
+
106
+ get_runtime ( reactor)
107
+ }
108
+
109
+ /// Install a new Tokio runtime on the reactor instance.
110
+ fn install_runtime ( reactor : & Bound < PyAny > ) -> PyResult < ( ) > {
111
+ let py = reactor. py ( ) ;
112
+ let runtime = PyTokioRuntime { runtime : None } ;
113
+ let runtime = runtime. into_pyobject ( py) ?;
114
+
115
+ // Attach the runtime to the reactor, starting it when the reactor is
116
+ // running, stopping it when the reactor is shutting down
117
+ reactor. call_method1 ( "callWhenRunning" , ( runtime. getattr ( "start" ) ?, ) ) ?;
118
+ reactor. call_method1 (
119
+ "addSystemEventTrigger" ,
120
+ ( "after" , "shutdown" , runtime. getattr ( "shutdown" ) ?) ,
121
+ ) ?;
122
+ reactor. setattr ( TOKIO_RUNTIME_ATTR , runtime) ?;
123
+
124
+ Ok ( ( ) )
125
+ }
126
+
127
+ /// Get a reference to a Tokio runtime handle stored on the reactor instance.
128
+ fn get_runtime < ' a > ( reactor : & Bound < ' a , PyAny > ) -> PyResult < PyRef < ' a , PyTokioRuntime > > {
129
+ // This will raise if `TOKIO_RUNTIME_ATTR` is not set or if it is
130
+ // not a `Runtime`. Careful that this could happen if the user sets it
131
+ // manually, or if multiple versions of `pyo3-twisted` are used!
132
+ let runtime: Bound < PyTokioRuntime > = reactor. getattr ( TOKIO_RUNTIME_ATTR ) ?. extract ( ) ?;
133
+ Ok ( runtime. borrow ( ) )
134
+ }
135
+
136
+ /// A reference to the `twisted.internet.defer` module.
137
+ static DEFER : OnceCell < PyObject > = OnceCell :: new ( ) ;
138
+
139
+ /// Access to the `twisted.internet.defer` module.
140
+ fn defer ( py : Python < ' _ > ) -> PyResult < & Bound < PyAny > > {
141
+ Ok ( DEFER
142
+ . get_or_try_init ( || py. import ( "twisted.internet.defer" ) . map ( Into :: into) ) ?
143
+ . bind ( py) )
144
+ }
53
145
54
146
/// Called when registering modules with python.
55
147
pub fn register_module ( py : Python < ' _ > , m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
56
148
let child_module: Bound < ' _ , PyModule > = PyModule :: new ( py, "http_client" ) ?;
57
149
child_module. add_class :: < HttpClient > ( ) ?;
58
150
59
- // Make sure we fail early if we can't build the lazy statics.
60
- LazyLock :: force ( & RUNTIME ) ;
61
- LazyLock :: force ( & DEFERRED_CLASS ) ;
151
+ // Make sure we fail early if we can't load some modules
152
+ defer ( py) ?;
62
153
63
154
m. add_submodule ( & child_module) ?;
64
155
65
156
// We need to manually add the module to sys.modules to make `from
66
- // synapse.synapse_rust import acl ` work.
157
+ // synapse.synapse_rust import http_client ` work.
67
158
py. import ( "sys" ) ?
68
159
. getattr ( "modules" ) ?
69
160
. set_item ( "synapse.synapse_rust.http_client" , child_module) ?;
@@ -72,26 +163,24 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
72
163
}
73
164
74
165
#[ pyclass]
75
- #[ derive( Clone ) ]
76
166
struct HttpClient {
77
167
client : reqwest:: Client ,
168
+ reactor : PyObject ,
78
169
}
79
170
80
171
#[ pymethods]
81
172
impl HttpClient {
82
173
#[ new]
83
- pub fn py_new ( user_agent : & str ) -> PyResult < HttpClient > {
84
- // The twisted reactor can only be imported after Synapse has been
85
- // imported, to allow Synapse to change the twisted reactor. If we try
86
- // and import the reactor too early twisted installs a default reactor,
87
- // which can't be replaced.
88
- LazyLock :: force ( & TWISTED_REACTOR ) ;
174
+ pub fn py_new ( reactor : Bound < PyAny > , user_agent : & str ) -> PyResult < HttpClient > {
175
+ // Make sure the runtime gets installed
176
+ let _ = runtime ( & reactor) ?;
89
177
90
178
Ok ( HttpClient {
91
179
client : reqwest:: Client :: builder ( )
92
180
. user_agent ( user_agent)
93
181
. build ( )
94
182
. context ( "building reqwest client" ) ?,
183
+ reactor : reactor. unbind ( ) ,
95
184
} )
96
185
}
97
186
@@ -129,7 +218,7 @@ impl HttpClient {
129
218
builder : RequestBuilder ,
130
219
response_limit : usize ,
131
220
) -> PyResult < Bound < ' a , PyAny > > {
132
- create_deferred ( py, async move {
221
+ create_deferred ( py, self . reactor . bind ( py ) , async move {
133
222
let response = builder. send ( ) . await . context ( "sending request" ) ?;
134
223
135
224
let status = response. status ( ) ;
@@ -159,43 +248,51 @@ impl HttpClient {
159
248
/// tokio runtime.
160
249
///
161
250
/// Does not handle deferred cancellation or contextvars.
162
- fn create_deferred < F , O > ( py : Python , fut : F ) -> PyResult < Bound < ' _ , PyAny > >
251
+ fn create_deferred < ' py , F , O > (
252
+ py : Python < ' py > ,
253
+ reactor : & Bound < ' py , PyAny > ,
254
+ fut : F ,
255
+ ) -> PyResult < Bound < ' py , PyAny > >
163
256
where
164
257
F : Future < Output = PyResult < O > > + Send + ' static ,
165
- for < ' a > O : IntoPyObject < ' a > ,
258
+ for < ' a > O : IntoPyObject < ' a > + Send + ' static ,
166
259
{
167
- let deferred = DEFERRED_CLASS . bind ( py) . call0 ( ) ?;
260
+ let deferred = defer ( py) ? . call_method0 ( "Deferred" ) ?;
168
261
let deferred_callback = deferred. getattr ( "callback" ) ?. unbind ( ) ;
169
262
let deferred_errback = deferred. getattr ( "errback" ) ?. unbind ( ) ;
170
263
171
- RUNTIME . spawn ( async move {
172
- // TODO: Is it safe to assert unwind safety here? I think so, as we
173
- // don't use anything that could be tainted by the panic afterwards.
174
- // Note that `.spawn(..)` asserts unwind safety on the future too.
175
- let res = AssertUnwindSafe ( fut) . catch_unwind ( ) . await ;
264
+ let rt = runtime ( reactor) ?;
265
+ let handle = rt. handle ( ) ?;
266
+ let task = handle. spawn ( fut) ;
267
+
268
+ // Unbind the reactor so that we can pass it to the task
269
+ let reactor = reactor. clone ( ) . unbind ( ) ;
270
+ handle. spawn ( async move {
271
+ let res = task. await ;
176
272
177
273
Python :: with_gil ( move |py| {
178
274
// Flatten the panic into standard python error
179
275
let res = match res {
180
276
Ok ( r) => r,
181
- Err ( panic_err) => {
182
- let panic_message = get_panic_message ( & panic_err) ;
183
- Err ( PyException :: new_err (
184
- PyString :: new ( py, panic_message) . unbind ( ) ,
185
- ) )
186
- }
277
+ Err ( join_err) => match join_err. try_into_panic ( ) {
278
+ Ok ( panic_err) => Err ( RustPanicError :: from_panic ( & panic_err) ) ,
279
+ Err ( err) => Err ( PyException :: new_err ( format ! ( "Task cancelled: {err}" ) ) ) ,
280
+ } ,
187
281
} ;
188
282
283
+ // Re-bind the reactor
284
+ let reactor = reactor. bind ( py) ;
285
+
189
286
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
190
287
match res {
191
288
Ok ( obj) => {
192
- TWISTED_REACTOR
193
- . call_method ( py , "callFromThread" , ( deferred_callback, obj) , None )
289
+ reactor
290
+ . call_method ( "callFromThread" , ( deferred_callback, obj) , None )
194
291
. expect ( "callFromThread should not fail" ) ; // There's nothing we can really do with errors here
195
292
}
196
293
Err ( err) => {
197
- TWISTED_REACTOR
198
- . call_method ( py , "callFromThread" , ( deferred_errback, err) , None )
294
+ reactor
295
+ . call_method ( "callFromThread" , ( deferred_errback, err) , None )
199
296
. expect ( "callFromThread should not fail" ) ; // There's nothing we can really do with errors here
200
297
}
201
298
}
@@ -204,15 +301,3 @@ where
204
301
205
302
Ok ( deferred)
206
303
}
207
-
208
- /// Try and get the panic message out of the panic
209
- fn get_panic_message < ' a > ( panic_err : & ' a ( dyn std:: any:: Any + Send + ' static ) ) -> & ' a str {
210
- // Apparently this is how you extract the panic message from a panic
211
- if let Some ( str_slice) = panic_err. downcast_ref :: < & str > ( ) {
212
- str_slice
213
- } else if let Some ( string) = panic_err. downcast_ref :: < String > ( ) {
214
- string
215
- } else {
216
- "unknown error"
217
- }
218
- }
0 commit comments