@@ -14,11 +14,16 @@ function cld_fast(n::T, d::T) where {T}
1414 x += n != d* x
1515end
1616
17+ tag (check:: Val{false} ,f,x) = nothing
18+ tag (check:: Val{true} ,f,x:: AbstractArray{V} ) where V = ForwardDiff. Tag (f, V)
19+
1720store_val! (r:: Base.RefValue{T} , x:: T ) where {T} = (r[] = x)
1821store_val! (r:: Ptr{T} , x:: T ) where {T} = Base. unsafe_store! (r, x)
1922
20- function evaluate_chunks! (f:: F , (r,Δx,x), start, stop, :: ForwardDiff.Chunk{C} ) where {F,C}
21- cfg = ForwardDiff. GradientConfig (f, x, ForwardDiff. Chunk {C} (), nothing )
23+ function evaluate_chunks! (f:: F , (r,Δx,x), start, stop, :: ForwardDiff.Chunk{C} , check:: Val{B} ) where {F,C,B}
24+ Tag = tag (check, f, x)
25+ TagType = typeof (Tag)
26+ cfg = ForwardDiff. GradientConfig (f, x, ForwardDiff. Chunk {C} (), Tag)
2227 N = length (x)
2328 last_stop = cld_fast (N, C)
2429 is_last = last_stop == stop
@@ -31,34 +36,35 @@ function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C})
3136 i = (c- 1 ) * C + 1
3237 ForwardDiff. seed! (xdual, x, i, seeds)
3338 ydual = f (xdual)
34- ForwardDiff. extract_gradient_chunk! (Nothing , Δx, ydual, i, C)
39+ ForwardDiff. extract_gradient_chunk! (TagType , Δx, ydual, i, C)
3540 ForwardDiff. seed! (xdual, x, i)
3641 end
3742 if is_last
3843 lastchunksize = C + N - last_stop* C
3944 lastchunkindex = N - lastchunksize + 1
4045 ForwardDiff. seed! (xdual, x, lastchunkindex, seeds, lastchunksize)
4146 _ydual = f (xdual)
42- ForwardDiff. extract_gradient_chunk! (Nothing , Δx, _ydual, lastchunkindex, lastchunksize)
47+ ForwardDiff. extract_gradient_chunk! (TagType , Δx, _ydual, lastchunkindex, lastchunksize)
4348 store_val! (r, ForwardDiff. value (_ydual))
4449 end
4550end
4651
47- function threaded_gradient! (f:: F , Δx:: AbstractVector , x:: AbstractVector , :: ForwardDiff.Chunk{C} ) where {F,C}
52+ function threaded_gradient! (f:: F , Δx:: AbstractVector , x:: AbstractVector , :: ForwardDiff.Chunk{C} , check = Val {false} () ) where {F,C}
4853 N = length (x)
4954 d = cld_fast (N, C)
5055 r = Ref {eltype(Δx)} ()
5156 batch ((d,min (d,Threads. nthreads ())), r, Δx, x) do rΔxx,start,stop
52- evaluate_chunks! (f, rΔxx, start, stop, ForwardDiff. Chunk {C} ())
57+ evaluate_chunks! (f, rΔxx, start, stop, ForwardDiff. Chunk {C} (), check )
5358 end
5459 r[]
5560end
5661
5762# ### in-place jac, out-of-place f ####
5863
59- function evaluate_jacobian_chunks! (f:: F , (Δx,x), start, stop, :: ForwardDiff.Chunk{C} ) where {F,C}
60- cfg = ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk {C} (), nothing )
61-
64+ function evaluate_jacobian_chunks! (f:: F , (Δx,x), start, stop, :: ForwardDiff.Chunk{C} , check:: Val{B} ) where {F,C,B}
65+ Tag = tag (check, f, x)
66+ TagType = typeof (Tag)
67+ cfg = ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk {C} (), Tag)
6268 # figure out loop bounds
6369 N = length (x)
6470 last_stop = cld_fast (N, C)
@@ -81,7 +87,7 @@ function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chu
8187
8288 # extract part of the Jacobian
8389 Δx_reshaped = ForwardDiff. reshape_jacobian (Δx, ydual, xdual)
84- ForwardDiff. extract_jacobian_chunk! (Nothing , Δx_reshaped, ydual, i, C)
90+ ForwardDiff. extract_jacobian_chunk! (TagType , Δx_reshaped, ydual, i, C)
8591 ForwardDiff. seed! (xdual, x, i)
8692 end
8793
@@ -98,23 +104,25 @@ function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chu
98104
99105 # extract part of the Jacobian
100106 _Δx_reshaped = ForwardDiff. reshape_jacobian (Δx, _ydual, xdual)
101- ForwardDiff. extract_jacobian_chunk! (Nothing , _Δx_reshaped, _ydual, lastchunkindex, lastchunksize)
107+ ForwardDiff. extract_jacobian_chunk! (TagType , _Δx_reshaped, _ydual, lastchunkindex, lastchunksize)
102108 end
103109end
104110
105- function threaded_jacobian! (f:: F , Δx:: AbstractArray , x:: AbstractArray , :: ForwardDiff.Chunk{C} ) where {F,C}
111+ function threaded_jacobian! (f:: F , Δx:: AbstractArray , x:: AbstractArray , :: ForwardDiff.Chunk{C} , check = Val {false} () ) where {F,C}
106112 N = length (x)
107113 d = cld_fast (N, C)
108114 batch ((d,min (d,Threads. nthreads ())), Δx, x) do Δxx,start,stop
109- evaluate_jacobian_chunks! (f, Δxx, start, stop, ForwardDiff. Chunk {C} ())
115+ evaluate_jacobian_chunks! (f, Δxx, start, stop, ForwardDiff. Chunk {C} (), check )
110116 end
111117 return Δx
112118end
113119
114120# # #### in-place jac, in-place f ####
115121
116- function evaluate_f_and_jacobian_chunks! (f!:: F , (y,Δx,x), start, stop, :: ForwardDiff.Chunk{C} ) where {F,C}
117- cfg = ForwardDiff. JacobianConfig (f!, y, x, ForwardDiff. Chunk {C} (), nothing )
122+ function evaluate_f_and_jacobian_chunks! (f!:: F , (y,Δx,x), start, stop, :: ForwardDiff.Chunk{C} , check:: Val{B} ) where {F,C,B}
123+ Tag = tag (check, f!, x)
124+ TagType = typeof (Tag)
125+ cfg = ForwardDiff. JacobianConfig (f!, y, x, ForwardDiff. Chunk {C} (), Tag)
118126
119127 # figure out loop bounds
120128 N = length (x)
@@ -138,7 +146,7 @@ function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::Forwar
138146 f! (ForwardDiff. seed! (ydual, y), xdual)
139147
140148 # extract part of the Jacobian
141- ForwardDiff. extract_jacobian_chunk! (Nothing , Δx_reshaped, ydual, i, C)
149+ ForwardDiff. extract_jacobian_chunk! (TagType , Δx_reshaped, ydual, i, C)
142150 ForwardDiff. seed! (xdual, x, i)
143151 end
144152
@@ -154,16 +162,16 @@ function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::Forwar
154162 f! (ForwardDiff. seed! (ydual, y), xdual)
155163
156164 # extract part of the Jacobian
157- ForwardDiff. extract_jacobian_chunk! (Nothing , Δx_reshaped, ydual, lastchunkindex, lastchunksize)
165+ ForwardDiff. extract_jacobian_chunk! (TagType , Δx_reshaped, ydual, lastchunkindex, lastchunksize)
158166 map! (ForwardDiff. value, y, ydual)
159167 end
160168end
161169
162- function threaded_jacobian! (f!:: F , y:: AbstractArray , Δx:: AbstractArray , x:: AbstractArray , :: ForwardDiff.Chunk{C} ) where {F,C}
170+ function threaded_jacobian! (f!:: F , y:: AbstractArray , Δx:: AbstractArray , x:: AbstractArray , :: ForwardDiff.Chunk{C} ,check = Val {false} () ) where {F,C}
163171 N = length (x)
164172 d = cld_fast (N, C)
165173 batch ((d,min (d,Threads. nthreads ())), y, Δx, x) do yΔxx,start,stop
166- evaluate_f_and_jacobian_chunks! (f!, yΔxx, start, stop, ForwardDiff. Chunk {C} ())
174+ evaluate_f_and_jacobian_chunks! (f!, yΔxx, start, stop, ForwardDiff. Chunk {C} (), check )
167175 end
168176 Δx
169177end
0 commit comments