Skip to content

Commit 9010b7f

Browse files
authored
rewrite jl_threading_run in julia (#35632)
1 parent 6eb96da commit 9010b7f

File tree

2 files changed

+26
-52
lines changed

2 files changed

+26
-52
lines changed

base/threadingconstructs.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ on [`threadid()`](@ref).
1818
"""
1919
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
2020

21+
function threading_run(func)
22+
ccall(:jl_enter_threaded_region, Cvoid, ())
23+
n = nthreads()
24+
tasks = Vector{Task}(undef, n)
25+
for i = 1:n
26+
t = Task(func)
27+
t.sticky = true
28+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, i-1)
29+
tasks[i] = t
30+
schedule(t)
31+
end
32+
try
33+
for i = 1:n
34+
wait(tasks[i])
35+
end
36+
finally
37+
ccall(:jl_exit_threaded_region, Cvoid, ())
38+
end
39+
end
40+
2141
function _threadsfor(iter,lbody)
2242
lidx = iter.args[1] # index
2343
range = iter.args[2]
@@ -66,7 +86,7 @@ function _threadsfor(iter,lbody)
6686
# only thread 1 can enter/exit _threadedregion
6787
Base.invokelatest(threadsfor_fun, true)
6888
else
69-
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
89+
threading_run(threadsfor_fun)
7090
end
7191
nothing
7292
end

src/threading.c

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -481,65 +481,19 @@ void jl_start_threads(void)
481481

482482
unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
483483

484-
// simple fork/join mode code
485-
JL_DLLEXPORT void jl_threading_run(jl_value_t *func)
484+
JL_DLLEXPORT void jl_enter_threaded_region(void)
486485
{
487-
jl_ptls_t ptls = jl_get_ptls_states();
488-
int8_t gc_state = jl_gc_unsafe_enter(ptls);
489-
size_t world = jl_world_counter;
490-
jl_method_instance_t *mfunc = jl_lookup_generic(&func, 1, jl_int32hash_fast(jl_return_address()), world);
491-
// Ignore constant return value for now.
492-
jl_code_instance_t *fptr = jl_compile_method_internal(mfunc, world);
493-
if (fptr->invoke == jl_fptr_const_return)
494-
return;
495-
496-
size_t nthreads = jl_n_threads;
497-
jl_svec_t *ts = jl_alloc_svec(nthreads);
498-
JL_GC_PUSH1(&ts);
499-
jl_value_t *wait_func = jl_get_global(jl_base_module, jl_symbol("wait"));
500-
jl_value_t *schd_func = jl_get_global(jl_base_module, jl_symbol("schedule"));
501-
// create and schedule all tasks
502486
_threadedregion += 1;
503-
for (int i = 0; i < nthreads; i++) {
504-
jl_value_t *args2[2];
505-
args2[0] = (jl_value_t*)jl_task_type;
506-
args2[1] = func;
507-
jl_task_t *t = (jl_task_t*)jl_apply(args2, 2);
508-
jl_svecset(ts, i, t);
509-
t->sticky = 1;
510-
t->tid = i;
511-
args2[0] = schd_func;
512-
args2[1] = (jl_value_t*)t;
513-
jl_apply(args2, 2);
514-
if (i == 1 && nthreads > 2) {
515-
// hint to threads that work is coming soon
516-
jl_wakeup_thread(-1);
517-
}
518-
}
519-
// join with all tasks
520-
JL_TRY {
521-
for (int i = 0; i < nthreads; i++) {
522-
jl_value_t *t = jl_svecref(ts, i);
523-
jl_value_t *args[2] = { wait_func, t };
524-
jl_apply(args, 2);
525-
}
526-
}
527-
JL_CATCH {
528-
_threadedregion -= 1;
529-
jl_wake_libuv();
530-
JL_UV_LOCK();
531-
JL_UV_UNLOCK();
532-
jl_rethrow();
533-
}
534-
// make sure no threads are sitting in the event loop
487+
}
488+
489+
JL_DLLEXPORT void jl_exit_threaded_region(void)
490+
{
535491
_threadedregion -= 1;
536492
jl_wake_libuv();
537493
// make sure no more callbacks will run while user code continues
538494
// outside thread region and might touch an I/O object.
539495
JL_UV_LOCK();
540496
JL_UV_UNLOCK();
541-
JL_GC_POP();
542-
jl_gc_unsafe_leave(ptls, gc_state);
543497
}
544498

545499

0 commit comments

Comments
 (0)