@@ -119,30 +119,23 @@ function runtime_pfor_augfwd(
119119 threading_args... ,
120120) where {ThunkTy,FT,AnyJL,byRef}
121121 TapeType = EnzymeRules. tape_type (ThunkTy)
122+
123+ n = Base. Threads. threadpoolsize ()
122124 tapes = if AnyJL
123- Vector {TapeType} (undef, Base . Threads . nthreads () )
125+ Vector {TapeType} (undef, n )
124126 else
125127 Base. unsafe_convert (
126128 Ptr{TapeType},
127- Libc. malloc (sizeof (TapeType) * Base . Threads . nthreads () ),
129+ Libc. malloc (sizeof (TapeType) * n ),
128130 )
129131 end
130132
131133 function fwd (tid_args... )
132- if length (tid_args) == 0
133- if byRef
134- tres = thunk (Const (referenceCaller), ft)
135- else
136- tres = thunk (ft)
137- end
138- tid = Base. Threads. threadid ()
134+ tid = tid_args[1 ]
135+ if byRef
136+ tres = thunk (Const (referenceCaller), ft, Const (tid))
139137 else
140- tid = tid_args[1 ]
141- if byRef
142- tres = thunk (Const (referenceCaller), ft, Const (tid))
143- else
144- tres = thunk (ft, Const (tid))
145- end
138+ tres = thunk (ft, Const (tid))
146139 end
147140
148141 if ! AnyJL
@@ -155,6 +148,29 @@ function runtime_pfor_augfwd(
155148 return tapes
156149end
157150
"""
    ReversePFor{ThunkTy, FT, AnyJL, byRef, TT}

Callable object that bundles what one reverse-pass invocation needs, so it can
be handed to `Base.Threads.threading_run` in place of a closure.

Fields:
- `thunk`: the callable invoked for each `tid`.
- `ft`: forwarded to `thunk` as an argument.
- `tapes`: per-`tid` tape storage; read with `unsafe_load` when `AnyJL` is
  `false` and by indexing when `AnyJL` is `true`.

`AnyJL` and `byRef` are not the type of any field; they exist only as type
parameters so the call method can branch on them.
"""
struct ReversePFor{ThunkTy,FT,AnyJL,byRef,TT}
    thunk::ThunkTy
    ft::FT
    tapes::TT
end
156+
# Call method for `ReversePFor`: runs the stored thunk for a single `tid`.
#
# The tape for `tid` is fetched from `st.tapes` (via `unsafe_load` when `AnyJL`
# is false, via indexing otherwise) and passed as the last argument to
# `st.thunk`. When `byRef` is true, `Const(referenceCaller)` is prepended to the
# argument list. The thunk's result is discarded; the method returns `nothing`.
function (st::ReversePFor{ThunkTy,FT,AnyJL,byRef,TT})(tid) where {ThunkTy,FT,AnyJL,byRef,TT}
    # `AnyJL` is a type parameter, so this selects the access path per type.
    tape = AnyJL ? (@inbounds st.tapes[tid]) : unsafe_load(st.tapes, tid)

    if !byRef
        st.thunk(st.ft, Const(tid), tape)
    else
        st.thunk(Const(referenceCaller), st.ft, Const(tid), tape)
    end

    return nothing
end
173+
158174function runtime_pfor_rev (
159175 thunk:: ThunkTy ,
160176 ft:: FT ,
@@ -163,35 +179,7 @@ function runtime_pfor_rev(
163179 tapes,
164180 threading_args... ,
165181) where {ThunkTy,FT,AnyJL,byRef}
166- function rev (tid_args... )
167- tid = if length (tid_args) == 0
168- tid = Base. Threads. threadid ()
169- else
170- tid_args[1 ]
171- end
172-
173- tres = if ! AnyJL
174- unsafe_load (tapes, tid)
175- else
176- @inbounds tapes[tid]
177- end
178-
179- if length (tid_args) == 0
180- if byRef
181- thunk (Const (referenceCaller), ft, tres)
182- else
183- thunk (ft, tres)
184- end
185- else
186- if byRef
187- thunk (Const (referenceCaller), ft, Const (tid), tres)
188- else
189- thunk (ft, Const (tid), tres)
190- end
191- end
192- end
193-
194- Base. Threads. threading_run (rev, threading_args... )
182+ Base. Threads. threading_run (ReversePFor {ThunkTy, FT, AnyJL, byRef, typeof(tapes)} (thunk, ft, tapes), threading_args... )
195183 if ! AnyJL
196184 Libc. free (tapes)
197185 end
0 commit comments