Skip to content

Commit d4f354d

Browse files
committed
Remove dynamic dispatch in Hessian evaluation
1 parent bd0bf71 commit d4f354d

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,41 @@
66

77
const TAG = :ReverseAD
88

9-
"""
10-
_eval_hessian(
9+
function _generate_eval_hessian()
10+
exprs = map(1:10) do chunk
11+
return :(return _eval_hessian_inner(d, f, H, λ, offset, Val($chunk)))
12+
end
13+
return Nonlinear._create_binary_switch(1:10, exprs)
14+
end
15+
16+
@eval begin
17+
"""
18+
_eval_hessian(
19+
d::NLPEvaluator,
20+
f::_FunctionStorage,
21+
H::AbstractVector{Float64},
22+
λ::Float64,
23+
offset::Int,
24+
)::Int
25+
26+
Evaluate the hessian matrix of the function `f` and store the result, scaled by
27+
`λ`, in `H`, beginning at element `offset+1`. This function assumes that
28+
`_reverse_mode(d, x)` has already been called.
29+
30+
Returns the number of non-zeros in the computed Hessian, which will be used to
31+
update the offset for the next call.
32+
"""
33+
function _eval_hessian(
1134
d::NLPEvaluator,
1235
f::_FunctionStorage,
1336
H::AbstractVector{Float64},
1437
λ::Float64,
1538
offset::Int,
1639
)::Int
17-
18-
Evaluate the hessian matrix of the function `f` and store the result, scaled by
19-
`λ`, in `H`, beginning at element `offset+1`. This function assumes that
20-
`_reverse_mode(d, x)` has already been called.
21-
22-
Returns the number of non-zeros in the computed Hessian, which will be used to
23-
update the offset for the next call.
24-
"""
25-
function _eval_hessian(
26-
d::NLPEvaluator,
27-
f::_FunctionStorage,
28-
H::AbstractVector{Float64},
29-
λ::Float64,
30-
offset::Int,
31-
)::Int
32-
chunk = min(size(f.seed_matrix, 2), d.max_chunk)
33-
# As a performance optimization, skip dynamic dispatch if the chunk is 1.
34-
if chunk == 1
35-
return _eval_hessian_inner(d, f, H, λ, offset, Val(1))
36-
else
37-
return _eval_hessian_inner(d, f, H, λ, offset, Val(chunk))
40+
id = min(size(f.seed_matrix, 2), d.max_chunk)
41+
# As a performance optimization, skip dynamic dispatch if the chunk is 1.
42+
$(_generate_eval_hessian())
43+
error("Invalid chunk `$id`. Please report this.")
3844
end
3945
end
4046

0 commit comments

Comments
 (0)