@@ -8,69 +8,47 @@ const TAG = :ReverseAD
88
99const MAX_CHUNK = 10
1010
11- function _generate_eval_hessian ()
12- exprs = map (1 : MAX_CHUNK) do chunk
13- return :(return _eval_hessian_inner (d, f, H, λ, offset, Val ($ chunk)))
14- end
15- return Nonlinear. _create_binary_switch (1 : MAX_CHUNK, exprs)
16- end
17-
18- @eval begin
19- """
20- _eval_hessian(
21- d::NLPEvaluator,
22- f::_FunctionStorage,
23- H::AbstractVector{Float64},
24- λ::Float64,
25- offset::Int,
26- )::Int
27-
28- Evaluate the hessian matrix of the function `f` and store the result, scaled by
29- `λ`, in `H`, beginning at element `offset+1`. This function assumes that
30- `_reverse_mode(d, x)` has already been called.
31-
32- Returns the number of non-zeros in the computed Hessian, which will be used to
33- update the offset for the next call.
34- """
35- function _eval_hessian (
11+ """
12+ _eval_hessian(
3613 d::NLPEvaluator,
3714 f::_FunctionStorage,
3815 H::AbstractVector{Float64},
3916 λ::Float64,
4017 offset::Int,
4118 )::Int
42- id = min (size (f. seed_matrix, 2 ), d. max_chunk)
43- # As a performance optimization, skip dynamic dispatch if the chunk is 1.
44- $ (_generate_eval_hessian ())
45- error (" Invalid chunk `$id `. Please report this." )
46- end
47- end
4819
49- function _eval_hessian_inner (
20+ Evaluate the Hessian matrix of the function `f` and store the result, scaled by
21+ `λ`, in `H`, beginning at element `offset+1`. This function assumes that
22+ `_reverse_mode(d, x)` has already been called.
23+
24+ Returns the number of non-zeros in the computed Hessian, which will be used to
25+ update the offset for the next call.
26+ """
27+ function _eval_hessian (
5028 d:: NLPEvaluator ,
5129 ex:: _FunctionStorage ,
5230 H:: AbstractVector{Float64} ,
5331 scale:: Float64 ,
5432 nzcount:: Int ,
55- :: Val{CHUNK} ,
56- ) where {CHUNK}
33+ ):: Int
5734 if ex. linearity == LINEAR
5835 @assert length (ex. hess_I) == 0
5936 return 0
6037 end
38+ chunk = min (size (ex. seed_matrix, 2 ), d. max_chunk)
6139 Coloring. prepare_seed_matrix! (ex. seed_matrix, ex. rinfo)
6240 # Compute hessian-vector products
6341 num_products = size (ex. seed_matrix, 2 ) # number of hessian-vector products
64- num_chunks = div (num_products, CHUNK )
42+ num_chunks = div (num_products, chunk )
6543 @assert size (ex. seed_matrix, 1 ) == length (ex. rinfo. local_indices)
66- for offset in 1 : CHUNK : (CHUNK * num_chunks)
67- _eval_hessian_chunk (d, ex, offset, CHUNK, Val (CHUNK) )
44+ for offset in 1 : chunk : (chunk * num_chunks)
45+ _eval_hessian_chunk (d, ex, offset, chunk, chunk )
6846 end
6947 # leftover chunk
70- remaining = num_products - CHUNK * num_chunks
48+ remaining = num_products - chunk * num_chunks
7149 if remaining > 0
72- offset = CHUNK * num_chunks + 1
73- _eval_hessian_chunk (d, ex, offset, remaining, Val (CHUNK) )
50+ offset = chunk * num_chunks + 1
51+ _eval_hessian_chunk (d, ex, offset, remaining, chunk )
7452 end
7553 want, got = nzcount + length (ex. hess_I), length (H)
7654 if want > got
@@ -98,32 +76,59 @@ function _eval_hessian_chunk(
9876 ex:: _FunctionStorage ,
9977 offset:: Int ,
10078 chunk:: Int ,
101- :: Val{CHUNK} ,
102- ) where {CHUNK}
79+ chunk_size :: Int ,
80+ )
10381 for r in eachindex (ex. rinfo. local_indices)
10482 # set up directional derivatives
10583 @inbounds idx = ex. rinfo. local_indices[r]
10684 # load up ex.seed_matrix[r,k,k+1,...,k+remaining-1] into input_ϵ
10785 for s in 1 : chunk
108- # If `chunk < CHUNK `, leaves junk in the unused components
109- d. input_ϵ[(idx- 1 )* CHUNK + s] = ex. seed_matrix[r, offset+ s- 1 ]
86+ # If `chunk < chunk_size `, leaves junk in the unused components
87+ d. input_ϵ[(idx- 1 )* chunk_size + s] = ex. seed_matrix[r, offset+ s- 1 ]
11088 end
11189 end
112- _hessian_slice_inner (d, ex, Val (CHUNK) )
90+ _hessian_slice_inner (d, ex, chunk_size )
11391 fill! (d. input_ϵ, 0.0 )
11492 # collect directional derivatives
11593 for r in eachindex (ex. rinfo. local_indices)
11694 @inbounds idx = ex. rinfo. local_indices[r]
11795 # load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
11896 for s in 1 : chunk
119- ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* CHUNK + s]
97+ ex. seed_matrix[r, offset+ s- 1 ] = d. output_ϵ[(idx- 1 )* chunk_size + s]
12098 end
12199 end
122100 return
123101end
124102
125- function _hessian_slice_inner (d, ex, :: Val{CHUNK} ) where {CHUNK}
126- T = ForwardDiff. Partials{CHUNK,Float64} # This is our element type.
# Map the runtime value of `chunk` onto the matching
# `ForwardDiff.Partials{N,Float64}` element type. Each branch spells the type
# out as a literal at its call site; this wrapper exists to avoid dynamic
# dispatch when forwarding to the `::Type{T}` method.
function _hessian_slice_inner(d, ex, chunk::Int)
    @assert 1 <= chunk <= MAX_CHUNK
    # The branches below are enumerated by hand, so they must be kept in sync
    # with `MAX_CHUNK`.
    @assert MAX_CHUNK == 10
    if chunk == 1
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{1,Float64})
        return nothing
    end
    if chunk == 2
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{2,Float64})
        return nothing
    end
    if chunk == 3
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{3,Float64})
        return nothing
    end
    if chunk == 4
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{4,Float64})
        return nothing
    end
    if chunk == 5
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{5,Float64})
        return nothing
    end
    if chunk == 6
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{6,Float64})
        return nothing
    end
    if chunk == 7
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{7,Float64})
        return nothing
    end
    if chunk == 8
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{8,Float64})
        return nothing
    end
    if chunk == 9
        _hessian_slice_inner(d, ex, ForwardDiff.Partials{9,Float64})
        return nothing
    end
    # Every other value falls through to the 10-wide case; given the assertions
    # above, that can only be `chunk == 10`.
    _hessian_slice_inner(d, ex, ForwardDiff.Partials{10,Float64})
    return nothing
end
130+
131+ function _hessian_slice_inner (d, ex, :: Type{T} ) where {T}
127132 fill! (d. output_ϵ, 0.0 )
128133 output_ϵ = _reinterpret_unsafe (T, d. output_ϵ)
129134 subexpr_forward_values_ϵ =
0 commit comments