Skip to content

Commit 9a42901

Browse files
wsmosesTestHit
andauthored
Winf (#2611)
* fix * fix * Parallel thread cleanup --------- Co-authored-by: William S. Moses <[email protected]>
1 parent f44193a commit 9a42901

File tree

1 file changed

+32
-44
lines changed

1 file changed

+32
-44
lines changed

src/rules/parallelrules.jl

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -119,30 +119,23 @@ function runtime_pfor_augfwd(
119119
threading_args...,
120120
) where {ThunkTy,FT,AnyJL,byRef}
121121
TapeType = EnzymeRules.tape_type(ThunkTy)
122+
123+
n = Base.Threads.threadpoolsize()
122124
tapes = if AnyJL
123-
Vector{TapeType}(undef, Base.Threads.nthreads())
125+
Vector{TapeType}(undef, n)
124126
else
125127
Base.unsafe_convert(
126128
Ptr{TapeType},
127-
Libc.malloc(sizeof(TapeType) * Base.Threads.nthreads()),
129+
Libc.malloc(sizeof(TapeType) * n),
128130
)
129131
end
130132

131133
function fwd(tid_args...)
132-
if length(tid_args) == 0
133-
if byRef
134-
tres = thunk(Const(referenceCaller), ft)
135-
else
136-
tres = thunk(ft)
137-
end
138-
tid = Base.Threads.threadid()
134+
tid = tid_args[1]
135+
if byRef
136+
tres = thunk(Const(referenceCaller), ft, Const(tid))
139137
else
140-
tid = tid_args[1]
141-
if byRef
142-
tres = thunk(Const(referenceCaller), ft, Const(tid))
143-
else
144-
tres = thunk(ft, Const(tid))
145-
end
138+
tres = thunk(ft, Const(tid))
146139
end
147140

148141
if !AnyJL
@@ -155,6 +148,29 @@ function runtime_pfor_augfwd(
155148
return tapes
156149
end
157150

151+
struct ReversePFor{ThunkTy, FT, AnyJL, byRef, TT}
152+
thunk::ThunkTy
153+
ft::FT
154+
tapes::TT
155+
end
156+
157+
function (st::ReversePFor{ThunkTy, FT, AnyJL, byRef, TT})(tid) where {ThunkTy, FT, AnyJL, byRef, TT}
158+
159+
tres = if !AnyJL
160+
unsafe_load(st.tapes, tid)
161+
else
162+
@inbounds st.tapes[tid]
163+
end
164+
165+
if byRef
166+
st.thunk(Const(referenceCaller), st.ft, Const(tid), tres)
167+
else
168+
st.thunk(st.ft, Const(tid), tres)
169+
end
170+
171+
nothing
172+
end
173+
158174
function runtime_pfor_rev(
159175
thunk::ThunkTy,
160176
ft::FT,
@@ -163,35 +179,7 @@ function runtime_pfor_rev(
163179
tapes,
164180
threading_args...,
165181
) where {ThunkTy,FT,AnyJL,byRef}
166-
function rev(tid_args...)
167-
tid = if length(tid_args) == 0
168-
tid = Base.Threads.threadid()
169-
else
170-
tid_args[1]
171-
end
172-
173-
tres = if !AnyJL
174-
unsafe_load(tapes, tid)
175-
else
176-
@inbounds tapes[tid]
177-
end
178-
179-
if length(tid_args) == 0
180-
if byRef
181-
thunk(Const(referenceCaller), ft, tres)
182-
else
183-
thunk(ft, tres)
184-
end
185-
else
186-
if byRef
187-
thunk(Const(referenceCaller), ft, Const(tid), tres)
188-
else
189-
thunk(ft, Const(tid), tres)
190-
end
191-
end
192-
end
193-
194-
Base.Threads.threading_run(rev, threading_args...)
182+
Base.Threads.threading_run(ReversePFor{ThunkTy, FT, AnyJL, byRef, typeof(tapes)}(thunk, ft, tapes), threading_args...)
195183
if !AnyJL
196184
Libc.free(tapes)
197185
end

0 commit comments

Comments
 (0)