@@ -29,6 +29,20 @@ function _atpyexit()
2929 return
3030end
3131
32+
33+ const MAIN_THREAD_TASK_LOCK = ReentrantLock ()
34+ const MAIN_THREAD_CHANNEL_INPUT = Channel (1 )
35+ const MAIN_THREAD_CHANNEL_OUTPUT = Channel (1 )
36+
37+ # Execute f() on the main thread.
38+ function on_main_thread (f)
39+ @lock MAIN_THREAD_TASK_LOCK begin
40+ put! (MAIN_THREAD_CHANNEL_INPUT, f)
41+ take! (MAIN_THREAD_CHANNEL_OUTPUT)
42+ end
43+ end
44+
45+
3246function init_context ()
3347
3448 CTX. is_embedded = hasproperty (Base. Main, :__PythonCall_libptr )
@@ -240,6 +254,15 @@ function init_context()
240254 " Only Python 3.9+ is supported, this is Python $(CTX. version) at $(CTX. exe_path=== missing ? " unknown location" : CTX. exe_path) ." ,
241255 )
242256
257+ main_thread_task = Task () do
258+ while true
259+ f = take! (MAIN_THREAD_CHANNEL_INPUT)
260+ put! (MAIN_THREAD_CHANNEL_OUTPUT, f ())
261+ end
262+ end
263+ set_task_tid! (main_thread_task, Threads. threadid ())
264+ schedule (main_thread_task)
265+
243266 @debug " Initialized PythonCall.jl" CTX. is_embedded CTX. is_initialized CTX. exe_path CTX. lib_path CTX. lib_ptr CTX. pyprogname CTX. pyhome CTX. version
244267
245268 return
@@ -260,3 +283,26 @@ const PYTHONCALL_PKGID = Base.PkgId(PYTHONCALL_UUID, "PythonCall")
260283
261284const PYCALL_UUID = Base. UUID (" 438e738f-606a-5dbb-bf0a-cddfbfd45ab0" )
262285const PYCALL_PKGID = Base. PkgId (PYCALL_UUID, " PyCall" )
286+
287+
288+ # taken from StableTasks.jl, itself taken from Dagger.jl
289+ function set_task_tid! (task:: Task , tid:: Integer )
290+ task. sticky = true
291+ ctr = 0
292+ while true
293+ ret = ccall (:jl_set_task_tid , Cint, (Any, Cint), task, tid- 1 )
294+ if ret == 1
295+ break
296+ elseif ret == 0
297+ yield ()
298+ else
299+ error (" Unexpected retcode from jl_set_task_tid: $ret " )
300+ end
301+ ctr += 1
302+ if ctr > 10
303+ @warn " Setting task TID to $tid failed, giving up!"
304+ return
305+ end
306+ end
307+ @assert Threads. threadid (task) == tid " jl_set_task_tid failed!"
308+ end
0 commit comments