Skip to content

Commit 232c834

Browse files
committed
Improve coroutine handling
1 parent a119756 commit 232c834

File tree

3 files changed

+101
-112
lines changed

3 files changed

+101
-112
lines changed

src/asgi/mod.rs

Lines changed: 68 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,27 @@ pub use websocket::{
2626
WebSocketConnectionScope, WebSocketReceiveMessage, WebSocketSendException, WebSocketSendMessage,
2727
};
2828

29+
fn execute_coroutine(
30+
py_func: PyObject,
31+
scope: HttpConnectionScope,
32+
rx_receiver: Receiver,
33+
tx_sender: Sender,
34+
) -> PyResult<()> {
35+
Python::with_gil(|py| {
36+
let scope_py = scope.into_pyobject(py)?;
37+
let coroutine = py_func.call1(py, (scope_py, rx_receiver, tx_sender))?;
38+
39+
let asyncio = py.import("asyncio")?;
40+
let loop_ = asyncio.call_method0("new_event_loop")?;
41+
asyncio.call_method1("set_event_loop", (&loop_,))?;
42+
43+
loop_.call_method1("run_until_complete", (coroutine,))?;
44+
loop_.call_method0("close")?;
45+
46+
Ok(())
47+
})
48+
}
49+
2950
pub async fn execute_asgi_http_scope(
3051
py_func: PyObject,
3152
request: Request,
@@ -38,7 +59,7 @@ pub async fn execute_asgi_http_scope(
3859
let (tx_sender, mut tx) = Sender::http();
3960

4061
// Channel to receive the response data
41-
let (response_tx, response_rx) = oneshot::channel::<(u16, Vec<(String, String)>, Vec<u8>)>();
62+
let (response_tx, response_rx) = oneshot::channel::<Result<(u16, Vec<(String, String)>, Vec<u8>), String>>();
4263

4364
// Channel to signal Python execution completion
4465
let (python_done_tx, python_done_rx) = oneshot::channel::<Result<(), String>>();
@@ -50,115 +71,83 @@ pub async fn execute_asgi_http_scope(
5071
let mut body = Vec::new();
5172
let mut response_started = false;
5273

53-
while let Some(msg) = tx.recv().await {
54-
match msg {
55-
HttpSendMessage::HttpResponseStart {
56-
status: s,
57-
headers: h,
58-
trailers: _
59-
} => {
60-
status = s;
61-
headers = h;
62-
response_started = true;
63-
}
64-
HttpSendMessage::HttpResponseBody {
65-
body: b,
66-
more_body
67-
} => {
68-
if response_started {
69-
body.extend_from_slice(&b);
70-
if !more_body {
71-
// Response is complete
72-
let _ = response_tx.send((status, headers, body));
73-
break;
74+
while let Some(msg_result) = tx.recv().await {
75+
match msg_result {
76+
Ok(msg) => {
77+
match msg {
78+
HttpSendMessage::HttpResponseStart {
79+
status: s,
80+
headers: h,
81+
trailers: _
82+
} => {
83+
status = s;
84+
headers = h;
85+
response_started = true;
86+
}
87+
HttpSendMessage::HttpResponseBody {
88+
body: b,
89+
more_body
90+
} => {
91+
if response_started {
92+
body.extend_from_slice(&b);
93+
if !more_body {
94+
// Response is complete
95+
let _ = response_tx.send(Ok((status, headers, body)));
96+
break;
97+
}
98+
}
7499
}
75100
}
76101
}
102+
Err(e) => {
103+
// Error from Python side
104+
let _ = response_tx.send(Err(e));
105+
break;
106+
}
77107
}
78108
}
79109
});
80110

81111
// Send the request body to the ASGI app
82112
let request_body = request.body().clone();
113+
let rx_for_request = rx.clone();
83114
tokio::spawn(async move {
84115
let request_message = HttpReceiveMessage::Request {
85116
body: request_body.to_vec(),
86117
more_body: false,
87118
};
88119

89-
if rx.send(request_message).is_err() {
90-
eprintln!("Failed to send request message");
120+
if rx_for_request.send(Ok(request_message)).is_err() {
121+
// Channel closed, but we can't do much about it here
122+
// The error will be handled by the main flow
91123
}
92124
});
93125

94126
// Run the ASGI app in a Python thread
95127
tokio::task::spawn_blocking(move || {
96-
Python::with_gil(|py| {
97-
let scope_py = scope.into_pyobject(py).unwrap();
98-
99-
// Call the ASGI app to get a coroutine
100-
let coroutine = match py_func.call1(py, (scope_py, rx_receiver, tx_sender)) {
101-
Ok(coro) => coro,
102-
Err(e) => {
103-
eprintln!("Failed to call ASGI app: {e}");
104-
let _ = python_done_tx.send(Err(format!("Failed to call ASGI app: {e}")));
105-
return;
106-
}
107-
};
128+
let result = execute_coroutine(py_func, scope, rx_receiver, tx_sender)
129+
.map_err(|e| format!("Failed to execute ASGI app: {e}"));
108130

109-
// Run the coroutine using asyncio
110-
let asyncio = match py.import("asyncio") {
111-
Ok(module) => module,
112-
Err(e) => {
113-
eprintln!("Failed to import asyncio: {e}");
114-
let _ = python_done_tx.send(Err(format!("Failed to import asyncio: {e}")));
115-
return;
116-
}
117-
};
118-
119-
// Create a new event loop for this request
120-
let loop_ = match asyncio.call_method0("new_event_loop") {
121-
Ok(loop_) => loop_,
122-
Err(e) => {
123-
eprintln!("Failed to create event loop: {e}");
124-
let _ = python_done_tx.send(Err(format!("Failed to create event loop: {e}")));
125-
return;
126-
}
127-
};
128-
129-
// Set it as the current event loop
130-
if let Err(e) = asyncio.call_method1("set_event_loop", (&loop_,)) {
131-
eprintln!("Failed to set event loop: {e}");
132-
let _ = python_done_tx.send(Err(format!("Failed to set event loop: {e}")));
133-
return;
134-
}
135-
136-
// Run the coroutine
137-
let result = loop_.call_method1("run_until_complete", (coroutine,));
138-
139-
// Close the loop
140-
let _ = loop_.call_method0("close");
141-
142-
// Send the result after a small delay to allow response to be processed
143-
// This ensures the response handler has time to receive messages from the ASGI app
144-
std::thread::sleep(std::time::Duration::from_millis(10));
145-
146-
let _ = python_done_tx.send(result.map(|_| ()).map_err(|e| {
147-
format!("Failed to run ASGI coroutine: {e}")
148-
}));
149-
});
131+
// If the coroutine execution was successful, we send a success signal
132+
// Otherwise, we send the error
133+
if let Err(err) = python_done_tx.send(result) {
134+
eprintln!("Send error: {err:?}");
135+
}
150136
});
151137

152138
// Wait for either the response or Python completion
153139
let result = tokio::select! {
154140
response = response_rx => {
155-
response.map_err(|_| pyo3::exceptions::PyRuntimeError::new_err("Failed to receive response"))
141+
match response {
142+
Ok(Ok(response_data)) => Ok(response_data),
143+
Ok(Err(e)) => Err(pyo3::exceptions::PyRuntimeError::new_err(e)),
144+
Err(_) => Err(pyo3::exceptions::PyRuntimeError::new_err("Failed to receive response"))
145+
}
156146
}
157147
python_result = python_done_rx => {
158148
match python_result {
159149
Ok(Ok(())) => {
160150
// Python completed but no response was sent
161-
eprintln!("ASGI app completed without sending response");
162151
Err(pyo3::exceptions::PyRuntimeError::new_err("ASGI app completed without sending response"))
163152
}
164153
Ok(Err(e)) => {

src/asgi/receiver.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ use crate::asgi::{
1111
};
1212

1313
enum ReceiverType {
14-
Http(Arc<Mutex<mpsc::UnboundedReceiver<HttpReceiveMessage>>>),
15-
WebSocket(Arc<Mutex<mpsc::UnboundedReceiver<WebSocketReceiveMessage>>>),
16-
Lifespan(Arc<Mutex<mpsc::UnboundedReceiver<LifespanReceiveMessage>>>),
14+
Http(Arc<Mutex<mpsc::UnboundedReceiver<Result<HttpReceiveMessage, String>>>>),
15+
WebSocket(Arc<Mutex<mpsc::UnboundedReceiver<Result<WebSocketReceiveMessage, String>>>>),
16+
Lifespan(Arc<Mutex<mpsc::UnboundedReceiver<Result<LifespanReceiveMessage, String>>>>),
1717
}
1818

1919
/// Allows Python to receive messages from Rust.
@@ -22,22 +22,22 @@ pub struct Receiver(ReceiverType);
2222

2323
impl Receiver {
2424
/// Create a new Receiver instance for http ASGI message types.
25-
pub fn http() -> (Receiver, mpsc::UnboundedSender<HttpReceiveMessage>) {
26-
let (tx, rx) = mpsc::unbounded_channel::<HttpReceiveMessage>();
25+
pub fn http() -> (Receiver, mpsc::UnboundedSender<Result<HttpReceiveMessage, String>>) {
26+
let (tx, rx) = mpsc::unbounded_channel::<Result<HttpReceiveMessage, String>>();
2727
let rx = Arc::new(Mutex::new(rx));
2828
(Receiver(ReceiverType::Http(rx)), tx)
2929
}
3030

3131
/// Create a new Receiver instance for websocket ASGI message types.
32-
pub fn websocket() -> (Receiver, mpsc::UnboundedSender<WebSocketReceiveMessage>) {
33-
let (tx, rx) = mpsc::unbounded_channel::<WebSocketReceiveMessage>();
32+
pub fn websocket() -> (Receiver, mpsc::UnboundedSender<Result<WebSocketReceiveMessage, String>>) {
33+
let (tx, rx) = mpsc::unbounded_channel::<Result<WebSocketReceiveMessage, String>>();
3434
let rx = Arc::new(Mutex::new(rx));
3535
(Receiver(ReceiverType::WebSocket(rx)), tx)
3636
}
3737

3838
/// Create a new Receiver instance for lifespan ASGI message types.
39-
pub fn lifespan() -> (Receiver, mpsc::UnboundedSender<LifespanReceiveMessage>) {
40-
let (tx, rx) = mpsc::unbounded_channel::<LifespanReceiveMessage>();
39+
pub fn lifespan() -> (Receiver, mpsc::UnboundedSender<Result<LifespanReceiveMessage, String>>) {
40+
let (tx, rx) = mpsc::unbounded_channel::<Result<LifespanReceiveMessage, String>>();
4141
let rx = Arc::new(Mutex::new(rx));
4242
(Receiver(ReceiverType::Lifespan(rx)), tx)
4343
}
@@ -49,26 +49,26 @@ impl Receiver {
4949
match &self.0 {
5050
ReceiverType::Http(rx) => {
5151
let message = rx.lock().await.recv().await;
52-
if let Some(msg) = message {
53-
Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind()))
54-
} else {
55-
Err(PyValueError::new_err("No message received"))
52+
match message {
53+
Some(Ok(msg)) => Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind())),
54+
Some(Err(err)) => Err(PyValueError::new_err(err)),
55+
None => Err(PyValueError::new_err("No message received"))
5656
}
5757
}
5858
ReceiverType::WebSocket(rx) => {
5959
let message = rx.lock().await.recv().await;
60-
if let Some(msg) = message {
61-
Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind()))
62-
} else {
63-
Err(PyValueError::new_err("No message received"))
60+
match message {
61+
Some(Ok(msg)) => Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind())),
62+
Some(Err(err)) => Err(PyValueError::new_err(err)),
63+
None => Err(PyValueError::new_err("No message received"))
6464
}
6565
}
6666
ReceiverType::Lifespan(rx) => {
6767
let message = rx.lock().await.recv().await;
68-
if let Some(msg) = message {
69-
Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind()))
70-
} else {
71-
Err(PyValueError::new_err("No message received"))
68+
match message {
69+
Some(Ok(msg)) => Python::with_gil(|py| Ok(msg.into_pyobject(py)?.unbind())),
70+
Some(Err(err)) => Err(PyValueError::new_err(err)),
71+
None => Err(PyValueError::new_err("No message received"))
7272
}
7373
}
7474
}

src/asgi/sender.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,28 @@ use crate::asgi::{
99
};
1010

1111
enum SenderType {
12-
Http(mpsc::UnboundedSender<HttpSendMessage>),
13-
WebSocket(mpsc::UnboundedSender<WebSocketSendMessage>),
14-
Lifespan(mpsc::UnboundedSender<LifespanSendMessage>),
12+
Http(mpsc::UnboundedSender<Result<HttpSendMessage, String>>),
13+
WebSocket(mpsc::UnboundedSender<Result<WebSocketSendMessage, String>>),
14+
Lifespan(mpsc::UnboundedSender<Result<LifespanSendMessage, String>>),
1515
}
1616

1717
/// Allows Python to send messages to Rust.
1818
#[pyclass]
1919
pub struct Sender(SenderType);
2020

2121
impl Sender {
22-
pub fn http() -> (Sender, mpsc::UnboundedReceiver<HttpSendMessage>) {
23-
let (tx, rx) = mpsc::unbounded_channel::<HttpSendMessage>();
22+
pub fn http() -> (Sender, mpsc::UnboundedReceiver<Result<HttpSendMessage, String>>) {
23+
let (tx, rx) = mpsc::unbounded_channel::<Result<HttpSendMessage, String>>();
2424
(Sender(SenderType::Http(tx)), rx)
2525
}
2626

27-
pub fn websocket() -> (Sender, mpsc::UnboundedReceiver<WebSocketSendMessage>) {
28-
let (tx, rx) = mpsc::unbounded_channel::<WebSocketSendMessage>();
27+
pub fn websocket() -> (Sender, mpsc::UnboundedReceiver<Result<WebSocketSendMessage, String>>) {
28+
let (tx, rx) = mpsc::unbounded_channel::<Result<WebSocketSendMessage, String>>();
2929
(Sender(SenderType::WebSocket(tx)), rx)
3030
}
3131

32-
pub fn lifespan() -> (Sender, mpsc::UnboundedReceiver<LifespanSendMessage>) {
33-
let (tx, rx) = mpsc::unbounded_channel::<LifespanSendMessage>();
32+
pub fn lifespan() -> (Sender, mpsc::UnboundedReceiver<Result<LifespanSendMessage, String>>) {
33+
let (tx, rx) = mpsc::unbounded_channel::<Result<LifespanSendMessage, String>>();
3434
(Sender(SenderType::Lifespan(tx)), rx)
3535
}
3636
}
@@ -42,17 +42,17 @@ impl Sender {
4242
match &self.0 {
4343
SenderType::Http(tx) => {
4444
let msg: HttpSendMessage = args.extract(py)?;
45-
tx.send(msg)
45+
tx.send(Ok(msg))
4646
.map_err(|_| PyValueError::new_err("connection closed"))?;
4747
}
4848
SenderType::WebSocket(tx) => {
4949
let msg: WebSocketSendMessage = args.extract(py)?;
50-
tx.send(msg)
50+
tx.send(Ok(msg))
5151
.map_err(|_| PyValueError::new_err("connection closed"))?;
5252
}
5353
SenderType::Lifespan(tx) => {
5454
let msg: LifespanSendMessage = args.extract(py)?;
55-
tx.send(msg)
55+
tx.send(Ok(msg))
5656
.map_err(|_| PyValueError::new_err("connection closed"))?;
5757
}
5858
};

0 commit comments

Comments
 (0)