Skip to content

Commit 7d2dc2d

Browse files
fix: async handler bridge + mount_points setter + cleanup coroutine handling
Three dogfooding bug fixes: 1. Added mount_points setter on PyCoordinator (Foundation needs to write to it) 2. PyHookHandlerBridge now properly awaits async Python handlers using run_coroutine_threadsafe 3. Cleanup coroutine handling uses run_coroutine_threadsafe instead of run_until_complete (which fails inside running event loops) 🤖 Generated with [Amplifier](https://github.com/microsoft/amplifier) Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com>
1 parent 4c1da1c commit 7d2dc2d

File tree

1 file changed

+89
-30
lines changed

1 file changed

+89
-30
lines changed

bindings/python/src/lib.rs

Lines changed: 89 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,39 +49,82 @@ impl HookHandler for PyHookHandlerBridge {
4949
&self,
5050
event: &str,
5151
data: Value,
52-
) -> Pin<Box<dyn Future<Output = Result<HookResult, HookError>> + Send + '_>> {
52+
) -> std::pin::Pin<
53+
Box<dyn std::future::Future<Output = Result<HookResult, HookError>> + Send + '_>,
54+
> {
5355
let event = event.to_string();
54-
let data_str = serde_json::to_string(&data).unwrap_or_else(|_| "{}".to_string());
55-
56+
// Clone the Py<PyAny> reference inside the GIL to safely move into async block
57+
let callable = Python::try_attach(|py| {
58+
Ok::<_, PyErr>(self.callable.clone_ref(py))
59+
}).unwrap().unwrap();
5660
Box::pin(async move {
57-
// Acquire GIL to call the Python callable.
58-
// Python::try_attach is the PyO3 0.28 way to get the GIL.
59-
let result = Python::try_attach(|py| -> PyResult<HookResult> {
61+
// Call the Python handler and handle both sync and async returns
62+
let result_json: String = Python::try_attach(|py| -> PyResult<String> {
6063
let json_mod = py.import("json")?;
64+
let data_str = serde_json::to_string(&data).unwrap_or_else(|_| "{}".to_string());
6165
let py_data = json_mod.call_method1("loads", (&data_str,))?;
6266

63-
let result = self.callable.call(py, (&event, py_data), None)?;
64-
65-
// If the callable returns None, treat as continue
66-
if result.is_none(py) {
67-
return Ok(HookResult::default());
67+
let call_result = callable.call(py, (&event, py_data), None)?;
68+
let bound = call_result.bind(py);
69+
70+
// Check if the result is a coroutine (async handler)
71+
let inspect = py.import("inspect")?;
72+
let is_coro: bool = inspect.call_method1("iscoroutine", (bound,))?.extract()?;
73+
74+
if is_coro {
75+
// Await the coroutine using asyncio
76+
let asyncio = py.import("asyncio")?;
77+
// Try to get the running loop and create a task
78+
// If we're in an async context, use ensure_future + loop.run_until_complete
79+
match asyncio.call_method1("get_running_loop", ()) {
80+
Ok(loop_) => {
81+
// We're inside a running loop — we can't run_until_complete.
82+
// Instead, use a thread to run the coroutine.
83+
// But for simplicity, let's try the concurrent.futures approach
84+
let concurrent = py.import("concurrent.futures")?;
85+
let thread_pool = concurrent.getattr("ThreadPoolExecutor")?.call1((1,))?;
86+
let future = asyncio.call_method1("run_coroutine_threadsafe", (bound, &loop_))?;
87+
let awaited = future.call_method1("result", (5.0,))?; // 5s timeout
88+
drop(thread_pool);
89+
90+
if awaited.is_none() {
91+
return Ok("{}".to_string());
92+
}
93+
let json_str: String = json_mod.call_method1("dumps", (&awaited,))?
94+
.extract()
95+
.unwrap_or_else(|_| "{}".to_string());
96+
Ok(json_str)
97+
}
98+
Err(_) => {
99+
// No running loop — use asyncio.run() in a new loop
100+
let awaited = asyncio.call_method1("run", (bound,))?;
101+
if awaited.is_none() {
102+
return Ok("{}".to_string());
103+
}
104+
let json_str: String = json_mod.call_method1("dumps", (&awaited,))?
105+
.extract()
106+
.unwrap_or_else(|_| "{}".to_string());
107+
Ok(json_str)
108+
}
109+
}
110+
} else {
111+
// Sync handler — process the result directly
112+
if bound.is_none() {
113+
return Ok("{}".to_string());
114+
}
115+
let json_str: String = json_mod.call_method1("dumps", (bound,))?
116+
.extract()
117+
.unwrap_or_else(|_| "{}".to_string());
118+
Ok(json_str)
68119
}
120+
})
121+
.ok_or_else(|| HookError::HandlerFailed { message: "Failed to attach to Python runtime".to_string(), handler_name: None })?
122+
.map_err(|e| HookError::HandlerFailed { message: format!("Python handler error: {e}"), handler_name: None })?;
69123

70-
// For any non-None return, default to continue
71-
// TODO(milestone-6): Parse dict result into full HookResult
72-
Ok(HookResult::default())
73-
});
74-
75-
match result {
76-
Some(Ok(hook_result)) => Ok(hook_result),
77-
Some(Err(py_err)) => Err(HookError::Other {
78-
message: format!("Python hook handler error: {py_err}"),
79-
}),
80-
None => {
81-
// No Python interpreter attached — return default
82-
Ok(HookResult::default())
83-
}
84-
}
124+
// Parse the JSON result into a HookResult
125+
let hook_result: HookResult = serde_json::from_str(&result_json)
126+
.unwrap_or_default();
127+
Ok(hook_result)
85128
})
86129
}
87130
}
@@ -822,6 +865,12 @@ impl PyCoordinator {
822865
Ok(self.mount_points.bind(py).clone())
823866
}
824867

868+
#[setter]
869+
fn set_mount_points(&mut self, _py: Python<'_>, value: Bound<'_, PyDict>) -> PyResult<()> {
870+
self.mount_points = value.unbind();
871+
Ok(())
872+
}
873+
825874
// -----------------------------------------------------------------------
826875
// Task 2.2: mount() and get()
827876
// -----------------------------------------------------------------------
@@ -1079,15 +1128,25 @@ impl PyCoordinator {
10791128
// Try calling; catch and log errors
10801129
match cleanup_fn.call0() {
10811130
Ok(result) => {
1082-
// If it returned a coroutine, we need to handle it
1131+
// If it returned a coroutine, await it properly
10831132
let inspect = py.import("inspect")?;
10841133
let is_coro: bool =
10851134
inspect.call_method1("iscoroutine", (&result,))?.extract()?;
10861135
if is_coro {
1087-
// Run the coroutine in the event loop
10881136
let asyncio = py.import("asyncio")?;
1089-
let _ = asyncio.call_method1("get_event_loop", ())
1090-
.and_then(|loop_| loop_.call_method1("run_until_complete", (&result,)));
1137+
// Try to schedule in the running loop
1138+
match asyncio.call_method1("get_running_loop", ()) {
1139+
Ok(loop_) => {
1140+
let future = asyncio.call_method1(
1141+
"run_coroutine_threadsafe", (&result, &loop_)
1142+
)?;
1143+
let _ = future.call_method1("result", (5.0,));
1144+
}
1145+
Err(_) => {
1146+
// No running loop, use asyncio.run
1147+
let _ = asyncio.call_method1("run", (&result,));
1148+
}
1149+
}
10911150
}
10921151
}
10931152
Err(e) => {

0 commit comments

Comments
 (0)