Skip to content

Commit d1880b8

Browse files
committed
Update
1 parent 9fa53c2 commit d1880b8

File tree

1 file changed

+53
-48
lines changed

1 file changed

+53
-48
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,69 +8,47 @@ const TAG = :ReverseAD
88

99
const MAX_CHUNK = 10
1010

11-
function _generate_eval_hessian()
12-
exprs = map(1:MAX_CHUNK) do chunk
13-
return :(return _eval_hessian_inner(d, f, H, λ, offset, Val($chunk)))
14-
end
15-
return Nonlinear._create_binary_switch(1:MAX_CHUNK, exprs)
16-
end
17-
18-
@eval begin
19-
"""
20-
_eval_hessian(
21-
d::NLPEvaluator,
22-
f::_FunctionStorage,
23-
H::AbstractVector{Float64},
24-
λ::Float64,
25-
offset::Int,
26-
)::Int
27-
28-
Evaluate the hessian matrix of the function `f` and store the result, scaled by
29-
`λ`, in `H`, beginning at element `offset+1`. This function assumes that
30-
`_reverse_mode(d, x)` has already been called.
31-
32-
Returns the number of non-zeros in the computed Hessian, which will be used to
33-
update the offset for the next call.
34-
"""
35-
function _eval_hessian(
11+
"""
12+
_eval_hessian(
3613
d::NLPEvaluator,
3714
f::_FunctionStorage,
3815
H::AbstractVector{Float64},
3916
λ::Float64,
4017
offset::Int,
4118
)::Int
42-
id = min(size(f.seed_matrix, 2), d.max_chunk)
43-
# As a performance optimization, skip dynamic dispatch if the chunk is 1.
44-
$(_generate_eval_hessian())
45-
error("Invalid chunk `$id`. Please report this.")
46-
end
47-
end
4819
49-
function _eval_hessian_inner(
20+
Evaluate the hessian matrix of the function `f` and store the result, scaled by
21+
`λ`, in `H`, beginning at element `offset+1`. This function assumes that
22+
`_reverse_mode(d, x)` has already been called.
23+
24+
Returns the number of non-zeros in the computed Hessian, which will be used to
25+
update the offset for the next call.
26+
"""
27+
function _eval_hessian(
5028
d::NLPEvaluator,
5129
ex::_FunctionStorage,
5230
H::AbstractVector{Float64},
5331
scale::Float64,
5432
nzcount::Int,
55-
::Val{CHUNK},
56-
) where {CHUNK}
33+
)::Int
5734
if ex.linearity == LINEAR
5835
@assert length(ex.hess_I) == 0
5936
return 0
6037
end
38+
chunk = min(size(ex.seed_matrix, 2), d.max_chunk)
6139
Coloring.prepare_seed_matrix!(ex.seed_matrix, ex.rinfo)
6240
# Compute hessian-vector products
6341
num_products = size(ex.seed_matrix, 2) # number of hessian-vector products
64-
num_chunks = div(num_products, CHUNK)
42+
num_chunks = div(num_products, chunk)
6543
@assert size(ex.seed_matrix, 1) == length(ex.rinfo.local_indices)
66-
for offset in 1:CHUNK:(CHUNK*num_chunks)
67-
_eval_hessian_chunk(d, ex, offset, CHUNK, Val(CHUNK))
44+
for offset in 1:chunk:(chunk*num_chunks)
45+
_eval_hessian_chunk(d, ex, offset, chunk, chunk)
6846
end
6947
# leftover chunk
70-
remaining = num_products - CHUNK * num_chunks
48+
remaining = num_products - chunk * num_chunks
7149
if remaining > 0
72-
offset = CHUNK * num_chunks + 1
73-
_eval_hessian_chunk(d, ex, offset, remaining, Val(CHUNK))
50+
offset = chunk * num_chunks + 1
51+
_eval_hessian_chunk(d, ex, offset, remaining, chunk)
7452
end
7553
want, got = nzcount + length(ex.hess_I), length(H)
7654
if want > got
@@ -98,32 +76,59 @@ function _eval_hessian_chunk(
9876
ex::_FunctionStorage,
9977
offset::Int,
10078
chunk::Int,
101-
::Val{CHUNK},
102-
) where {CHUNK}
79+
chunk_size::Int,
80+
)
10381
for r in eachindex(ex.rinfo.local_indices)
10482
# set up directional derivatives
10583
@inbounds idx = ex.rinfo.local_indices[r]
10684
# load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
10785
for s in 1:chunk
108-
# If `chunk < CHUNK`, leaves junk in the unused components
109-
d.input_ϵ[(idx-1)*CHUNK+s] = ex.seed_matrix[r, offset+s-1]
86+
# If `chunk < chunk_size`, leaves junk in the unused components
87+
d.input_ϵ[(idx-1)*chunk_size+s] = ex.seed_matrix[r, offset+s-1]
11088
end
11189
end
112-
_hessian_slice_inner(d, ex, Val(CHUNK))
90+
_hessian_slice_inner(d, ex, chunk_size)
11391
fill!(d.input_ϵ, 0.0)
11492
# collect directional derivatives
11593
for r in eachindex(ex.rinfo.local_indices)
11694
@inbounds idx = ex.rinfo.local_indices[r]
11795
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
11896
for s in 1:chunk
119-
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*CHUNK+s]
97+
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*chunk_size+s]
12098
end
12199
end
122100
return
123101
end
124102

125-
function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
126-
T = ForwardDiff.Partials{CHUNK,Float64} # This is our element type.
103+
# A wrapper function to avoid dynamic dispatch.
104+
function _hessian_slice_inner(d, ex, chunk::Int)
105+
@assert 1 <= chunk <= MAX_CHUNK
106+
@assert MAX_CHUNK == 10
107+
if chunk == 1
108+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{1,Float64})
109+
elseif chunk == 2
110+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{2,Float64})
111+
elseif chunk == 3
112+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{3,Float64})
113+
elseif chunk == 4
114+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{4,Float64})
115+
elseif chunk == 5
116+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{5,Float64})
117+
elseif chunk == 6
118+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{6,Float64})
119+
elseif chunk == 7
120+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{7,Float64})
121+
elseif chunk == 8
122+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{8,Float64})
123+
elseif chunk == 9
124+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{9,Float64})
125+
else
126+
_hessian_slice_inner(d, ex, ForwardDiff.Partials{10,Float64})
127+
end
128+
return
129+
end
130+
131+
function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
127132
fill!(d.output_ϵ, 0.0)
128133
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
129134
subexpr_forward_values_ϵ =

0 commit comments

Comments
 (0)