Skip to content

Commit ddee40c

Browse files
Add native Julia multi-threading
Since the tetrahedral amount of work leads to a load-balancing issue when chunking the iterations, a step-wise distribution of the work is used instead.
1 parent 058f0e8 commit ddee40c

File tree

7 files changed

+113
-105
lines changed

7 files changed

+113
-105
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
docs/build/
22
docs/site/
3-
*/*.dylib

deps/Makefile

Lines changed: 0 additions & 30 deletions
This file was deleted.

deps/Rotations.c

Lines changed: 0 additions & 62 deletions
This file was deleted.

deps/build.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/FastTransforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ export triones, trizeros, trirand, trirandn, trievaluate
5555
#export fejer2, fejer_plan2, fejerweights2
5656
#export RecurrencePlan, forward_recurrence!, backward_recurrence
5757

58+
include("stepthreading.jl")
5859
include("fftBigFloat.jl")
5960
include("specialfunctions.jl")
6061
include("clenshawcurtis.jl")

src/SphericalHarmonics/slowplan.jl

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,53 @@ function RotationPlan(::Type{T}, n::Int) where T
9999
RotationPlan(layers, snm, cnm)
100100
end
101101

102-
const rotpath = joinpath(Pkg.dir("FastTransforms"), "deps", "rotpar")
103-
104-
function Base.A_mul_B!(P::RotationPlan{Float64}, A::AbstractMatrix{Float64})
105-
M, N = size(A)
106-
ccall((:julia_apply_givens, rotpath), Void, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Int64, Int64), A, P.snm, P.cnm, M, N)
102+
"""
    A_mul_B!(P::RotationPlan, A::AbstractMatrix)

Apply the Givens rotations stored in `P` (sines `P.snm`, cosines `P.cnm`) to `A`
in place and return `A`.

With `N, M = size(A)`, every `m` in `M÷2:-1:2` rotates the column pair `2m`, `2m+1`.
For each `j = m, m-2, …, 2` and `l = N-j, …, 1`, rows `l` and `l+2` of both columns
are combined with the same `(c, s)` pair. Different `m` touch different columns, so
the outer loop is distributed over threads with `@stepthreads`.
"""
function Base.A_mul_B!(P::RotationPlan, A::AbstractMatrix)
    N, M = size(A)
    sines, cosines = P.snm, P.cnm
    @stepthreads for m = M÷2:-1:2
        # Linear-index offsets of columns 2m and 2m+1.
        # NOTE(review): accessed under @inbounds; assumes A has at least
        # 2*(M÷2)+1 columns (i.e. M is odd) — confirm against callers.
        lo = N*(2*m-1)
        hi = N*(2*m)
        @inbounds for j = m:-2:2
            # Offset of the block of coefficients used for this `j`.
            layer = (j-2)*(2*N+3-j)÷2
            for l = N-j:-1:1
                s = sines[l+layer]
                c = cosines[l+layer]
                a1 = A[l+lo]
                a2 = A[l+2+lo]
                a3 = A[l+hi]
                a4 = A[l+2+hi]
                A[l+lo] = c*a1 + s*a2
                A[l+2+lo] = c*a2 - s*a1
                A[l+hi] = c*a3 + s*a4
                A[l+2+hi] = c*a4 - s*a3
            end
        end
    end
    return A
end
109124

110-
function Base.At_mul_B!(P::RotationPlan{Float64}, A::AbstractMatrix{Float64})
111-
M, N = size(A)
112-
ccall((:julia_apply_givens_t, rotpath), Void, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Int64, Int64), A, P.snm, P.cnm, M, N)
125+
"""
    At_mul_B!(P::RotationPlan, A::AbstractMatrix)

Apply the transposed rotations of `P` to `A` in place and return `A`.

Compared with `A_mul_B!`, the `j` and `l` loops run in the opposite order
(`j = 2, 4, …` and `l = 1, …, N-j`) and the sign of every sine term is flipped.
The outer loop over column pairs is distributed over threads with `@stepthreads`.
"""
function Base.At_mul_B!(P::RotationPlan, A::AbstractMatrix)
    N, M = size(A)
    sines, cosines = P.snm, P.cnm
    @stepthreads for m = M÷2:-1:2
        # Linear-index offsets of columns 2m and 2m+1.
        # NOTE(review): accessed under @inbounds; assumes A has at least
        # 2*(M÷2)+1 columns (i.e. M is odd) — confirm against callers.
        lo = N*(2*m-1)
        hi = N*(2*m)
        @inbounds for j = reverse(m:-2:2)
            # Offset of the block of coefficients used for this `j`.
            layer = (j-2)*(2*N+3-j)÷2
            for l = 1:N-j
                s = sines[l+layer]
                c = cosines[l+layer]
                a1 = A[l+lo]
                a2 = A[l+2+lo]
                a3 = A[l+hi]
                a4 = A[l+2+hi]
                A[l+lo] = c*a1 - s*a2
                A[l+2+lo] = c*a2 + s*a1
                A[l+hi] = c*a3 - s*a4
                A[l+2+hi] = c*a4 + s*a3
            end
        end
    end
    return A
end
115147

148+
#=
116149
function Base.A_mul_B!(P::RotationPlan, A::AbstractMatrix)
117150
M, N = size(A)
118151
@inbounds for m = N÷2-2:-1:0
@@ -150,6 +183,7 @@ function Base.At_mul_B!(P::RotationPlan, A::AbstractMatrix)
150183
end
151184
A
152185
end
186+
=#
153187

154188
# The conjugate-transposed application simply forwards to the transposed one.
function Base.Ac_mul_B!(P::RotationPlan, A::AbstractMatrix)
    return At_mul_B!(P, A)
end
155189

src/stepthreading.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Build the expansion of `@stepthreads for <index> = <range> ... end`.
#
# `iter` is the loop header expression (`index = range`) and `lbody` is the loop body.
# The generated code wraps the loop in a closure that is run through
# `jl_threading_run`. Instead of handing each thread one contiguous chunk,
# thread `t` executes iterations `t, t + nthreads(), t + 2nthreads(), ...` of the range.
#
# When the loop is reached from inside another threaded loop (or from a thread other
# than 1), the closure is called directly with `onethread = true`, and the calling
# thread then executes *every* iteration itself with stride 1.
function _stepthreadsfor(iter,lbody)
    lidx = iter.args[1]         # index
    range = iter.args[2]
    quote
        local stepthreadsfor_fun
        let range = $(esc(range))
        function stepthreadsfor_fun(onethread=false)
            r = range # Load into local variable
            lenr = length(r)
            if onethread
                # Nested call: no other thread will pick up the remaining
                # iterations, so walk the whole range.
                firstidx, stride = 1, 1
            else
                # Strided distribution: one residue class per thread. A thread whose
                # id exceeds `lenr` gets an empty range and does nothing.
                firstidx, stride = Threads.threadid(), Threads.nthreads()
            end
            # run this thread's iterations
            for i = firstidx:stride:lenr
                local $(esc(lidx)) = Base.unsafe_getindex(r,i)
                $(esc(lbody))
            end
        end
        end
        # Hack to make nested threaded loops kinda work
        if Threads.threadid() != 1 || Threads.in_threaded_loop[]
            # We are in a nested threaded loop
            stepthreadsfor_fun(true)
        else
            Threads.in_threaded_loop[] = true
            # the ccall is not expected to throw
            ccall(:jl_threading_run, Ref{Void}, (Any,), stepthreadsfor_fun)
            Threads.in_threaded_loop[] = false
        end
        nothing
    end
end
49+
"""
    @stepthreads for index = range ... end

Run a `for` loop with multiple threads, handing out the iterations in steps rather
than in contiguous chunks: thread `t` executes iterations `t`, `t + nthreads()`,
`t + 2nthreads()`, … of `range`. This keeps the threads evenly loaded when the cost
of an iteration varies monotonically along the range.

A barrier is placed at the end of the loop which waits for all the threads to finish
execution, and the loop returns.
"""
macro stepthreads(args...)
    length(args) == 1 ||
        throw(ArgumentError("wrong number of arguments in @stepthreads"))
    loop = args[1]
    isa(loop, Expr) ||
        throw(ArgumentError("need an expression argument to @stepthreads"))
    loop.head === :for ||
        throw(ArgumentError("unrecognized argument to @stepthreads"))
    # loop.args[1] is the `index = range` header, loop.args[2] the loop body.
    return _stepthreadsfor(loop.args[1], loop.args[2])
end

0 commit comments

Comments
 (0)