Skip to content

Commit c22b173

Browse files
Merge pull request #21 from longemen3000/tag
add check argument
2 parents fbad79f + b87c5b1 commit c22b173

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ It's pretty much the same as ForwardDiff.jl except it is threaded. The API is th
1717
PolyesterForwardDiff.threaded_gradient!(f, dx, x, ForwardDiff.Chunk(8));
1818
PolyesterForwardDiff.threaded_jacobian!(g, dx, x, ForwardDiff.Chunk(8));
1919
PolyesterForwardDiff.threaded_jacobian!(g!, y, dx, x, ForwardDiff.Chunk(8));
20+
PolyesterForwardDiff.threaded_gradient!(f, dx, x, ForwardDiff.Chunk(8),Val{true}()); #To enable tag checking
2021
```
2122

2223
## Citing

src/PolyesterForwardDiff.jl

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ function cld_fast(n::T, d::T) where {T}
1414
x += n != d*x
1515
end
1616

17+
tag(check::Val{false},f,x) = nothing
18+
tag(check::Val{true},f,x::AbstractArray{V}) where V = ForwardDiff.Tag(f, V)
19+
1720
store_val!(r::Base.RefValue{T}, x::T) where {T} = (r[] = x)
1821
store_val!(r::Ptr{T}, x::T) where {T} = Base.unsafe_store!(r, x)
1922

20-
function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C}) where {F,C}
21-
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{C}(), nothing)
23+
function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C}, check::Val{B}) where {F,C,B}
24+
Tag = tag(check, f, x)
25+
TagType = typeof(Tag)
26+
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{C}(), Tag)
2227
N = length(x)
2328
last_stop = cld_fast(N, C)
2429
is_last = last_stop == stop
@@ -31,34 +36,35 @@ function evaluate_chunks!(f::F, (r,Δx,x), start, stop, ::ForwardDiff.Chunk{C})
3136
i = (c-1) * C + 1
3237
ForwardDiff.seed!(xdual, x, i, seeds)
3338
ydual = f(xdual)
34-
ForwardDiff.extract_gradient_chunk!(Nothing, Δx, ydual, i, C)
39+
ForwardDiff.extract_gradient_chunk!(TagType, Δx, ydual, i, C)
3540
ForwardDiff.seed!(xdual, x, i)
3641
end
3742
if is_last
3843
lastchunksize = C + N - last_stop*C
3944
lastchunkindex = N - lastchunksize + 1
4045
ForwardDiff.seed!(xdual, x, lastchunkindex, seeds, lastchunksize)
4146
_ydual = f(xdual)
42-
ForwardDiff.extract_gradient_chunk!(Nothing, Δx, _ydual, lastchunkindex, lastchunksize)
47+
ForwardDiff.extract_gradient_chunk!(TagType, Δx, _ydual, lastchunkindex, lastchunksize)
4348
store_val!(r, ForwardDiff.value(_ydual))
4449
end
4550
end
4651

47-
function threaded_gradient!(f::F, Δx::AbstractVector, x::AbstractVector, ::ForwardDiff.Chunk{C}) where {F,C}
52+
function threaded_gradient!(f::F, Δx::AbstractVector, x::AbstractVector, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
4853
N = length(x)
4954
d = cld_fast(N, C)
5055
r = Ref{eltype(Δx)}()
5156
batch((d,min(d,Threads.nthreads())), r, Δx, x) do rΔxx,start,stop
52-
evaluate_chunks!(f, rΔxx, start, stop, ForwardDiff.Chunk{C}())
57+
evaluate_chunks!(f, rΔxx, start, stop, ForwardDiff.Chunk{C}(), check)
5358
end
5459
r[]
5560
end
5661

5762
#### in-place jac, out-of-place f ####
5863

59-
function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chunk{C}) where {F,C}
60-
cfg = ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk{C}(), nothing)
61-
64+
function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chunk{C}, check::Val{B}) where {F,C,B}
65+
Tag = tag(check, f, x)
66+
TagType = typeof(Tag)
67+
cfg = ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk{C}(), Tag)
6268
# figure out loop bounds
6369
N = length(x)
6470
last_stop = cld_fast(N, C)
@@ -81,7 +87,7 @@ function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chu
8187

8288
# extract part of the Jacobian
8389
Δx_reshaped = ForwardDiff.reshape_jacobian(Δx, ydual, xdual)
84-
ForwardDiff.extract_jacobian_chunk!(Nothing, Δx_reshaped, ydual, i, C)
90+
ForwardDiff.extract_jacobian_chunk!(TagType, Δx_reshaped, ydual, i, C)
8591
ForwardDiff.seed!(xdual, x, i)
8692
end
8793

@@ -98,23 +104,25 @@ function evaluate_jacobian_chunks!(f::F, (Δx,x), start, stop, ::ForwardDiff.Chu
98104

99105
# extract part of the Jacobian
100106
_Δx_reshaped = ForwardDiff.reshape_jacobian(Δx, _ydual, xdual)
101-
ForwardDiff.extract_jacobian_chunk!(Nothing, _Δx_reshaped, _ydual, lastchunkindex, lastchunksize)
107+
ForwardDiff.extract_jacobian_chunk!(TagType, _Δx_reshaped, _ydual, lastchunkindex, lastchunksize)
102108
end
103109
end
104110

105-
function threaded_jacobian!(f::F, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C}) where {F,C}
111+
function threaded_jacobian!(f::F, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C}, check = Val{false}()) where {F,C}
106112
N = length(x)
107113
d = cld_fast(N, C)
108114
batch((d,min(d,Threads.nthreads())), Δx, x) do Δxx,start,stop
109-
evaluate_jacobian_chunks!(f, Δxx, start, stop, ForwardDiff.Chunk{C}())
115+
evaluate_jacobian_chunks!(f, Δxx, start, stop, ForwardDiff.Chunk{C}(), check)
110116
end
111117
return Δx
112118
end
113119

114120
# # #### in-place jac, in-place f ####
115121

116-
function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::ForwardDiff.Chunk{C}) where {F,C}
117-
cfg = ForwardDiff.JacobianConfig(f!, y, x, ForwardDiff.Chunk{C}(), nothing)
122+
function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::ForwardDiff.Chunk{C}, check::Val{B}) where {F,C,B}
123+
Tag = tag(check, f!, x)
124+
TagType = typeof(Tag)
125+
cfg = ForwardDiff.JacobianConfig(f!, y, x, ForwardDiff.Chunk{C}(), Tag)
118126

119127
# figure out loop bounds
120128
N = length(x)
@@ -138,7 +146,7 @@ function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::Forwar
138146
f!(ForwardDiff.seed!(ydual, y), xdual)
139147

140148
# extract part of the Jacobian
141-
ForwardDiff.extract_jacobian_chunk!(Nothing, Δx_reshaped, ydual, i, C)
149+
ForwardDiff.extract_jacobian_chunk!(TagType, Δx_reshaped, ydual, i, C)
142150
ForwardDiff.seed!(xdual, x, i)
143151
end
144152

@@ -154,16 +162,16 @@ function evaluate_f_and_jacobian_chunks!(f!::F, (y,Δx,x), start, stop, ::Forwar
154162
f!(ForwardDiff.seed!(ydual, y), xdual)
155163

156164
# extract part of the Jacobian
157-
ForwardDiff.extract_jacobian_chunk!(Nothing, Δx_reshaped, ydual, lastchunkindex, lastchunksize)
165+
ForwardDiff.extract_jacobian_chunk!(TagType, Δx_reshaped, ydual, lastchunkindex, lastchunksize)
158166
map!(ForwardDiff.value, y, ydual)
159167
end
160168
end
161169

162-
function threaded_jacobian!(f!::F, y::AbstractArray, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C}) where {F,C}
170+
function threaded_jacobian!(f!::F, y::AbstractArray, Δx::AbstractArray, x::AbstractArray, ::ForwardDiff.Chunk{C},check = Val{false}()) where {F,C}
163171
N = length(x)
164172
d = cld_fast(N, C)
165173
batch((d,min(d,Threads.nthreads())), y, Δx, x) do yΔxx,start,stop
166-
evaluate_f_and_jacobian_chunks!(f!, yΔxx, start, stop, ForwardDiff.Chunk{C}())
174+
evaluate_f_and_jacobian_chunks!(f!, yΔxx, start, stop, ForwardDiff.Chunk{C}(), check)
167175
end
168176
Δx
169177
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ g(x) = A*x
2727
PolyesterForwardDiff.threaded_jacobian!(g, dx, x, ForwardDiff.Chunk(8));
2828
ForwardDiff.jacobian!(dxref, g, x, ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(8), nothing));
2929
@test dx ≈ dxref
30+
PolyesterForwardDiff.threaded_jacobian!(g, dx, x, ForwardDiff.Chunk(8),Val{true}());
31+
@test dx ≈ dxref
3032

3133
PolyesterForwardDiff.threaded_jacobian!(g!, y, dx, x, ForwardDiff.Chunk(8));
3234
ForwardDiff.jacobian!(dxref, g!, yref, x, ForwardDiff.JacobianConfig(g!, yref, x, ForwardDiff.Chunk(8), nothing));
3335
@test dx ≈ dxref
3436
@test y ≈ yref
37+
PolyesterForwardDiff.threaded_jacobian!(g!, y, dx, x, ForwardDiff.Chunk(8),Val{true}());
38+
@test dx ≈ dxref
39+
@test y ≈ yref

0 commit comments

Comments
 (0)