Skip to content

Commit 13a9df6

Browse files
committed
Fixes
1 parent a80f0ab commit 13a9df6

File tree

3 files changed

+20
-42
lines changed

3 files changed

+20
-42
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,9 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
128128
_reinterpret_unsafe(T, d.subexpression_forward_values_ϵ)
129129
for i in ex.dependent_subexpressions
130130
subexpr = d.subexpressions[i]
131-
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
132-
d,
133-
subexpr,
134-
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
135-
)
131+
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(d, subexpr, T)
136132
end
137-
_forward_eval_ϵ(d, ex, _reinterpret_unsafe(T, d.partials_storage_ϵ))
133+
_forward_eval_ϵ(d, ex.expr, T)
138134
# do a reverse pass
139135
subexpr_reverse_values_ϵ =
140136
_reinterpret_unsafe(T, d.subexpression_reverse_values_ϵ)
@@ -144,9 +140,8 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
144140
end
145141
_reverse_eval_ϵ(
146142
output_ϵ,
147-
ex,
143+
ex.expr,
148144
_reinterpret_unsafe(T, d.storage_ϵ),
149-
_reinterpret_unsafe(T, d.partials_storage_ϵ),
150145
d.subexpression_reverse_values,
151146
subexpr_reverse_values_ϵ,
152147
1.0,
@@ -159,7 +154,6 @@ function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
159154
output_ϵ,
160155
subexpr,
161156
_reinterpret_unsafe(T, d.storage_ϵ),
162-
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
163157
d.subexpression_reverse_values,
164158
subexpr_reverse_values_ϵ,
165159
d.subexpression_reverse_values[j],
@@ -173,8 +167,8 @@ end
173167
_forward_eval_ϵ(
174168
d::NLPEvaluator,
175169
ex::Union{_FunctionStorage,_SubexpressionStorage},
176-
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
177-
) where {N,T}
170+
::Type{P},
171+
) where {N,T,P<:ForwardDiff.Partials{N,T}}
178172
179173
Evaluate the directional derivatives of the expression tree in `ex`.
180174
@@ -186,10 +180,11 @@ This assumes that `_reverse_model(d, x)` has already been called.
186180
"""
187181
function _forward_eval_ϵ(
188182
d::NLPEvaluator,
189-
ex::Union{_FunctionStorage,_SubexpressionStorage},
190-
partials_storage_ϵ::AbstractVector{P},
183+
ex::_SubexpressionStorage,
184+
::Type{P},
191185
) where {N,T,P<:ForwardDiff.Partials{N,T}}
192186
storage_ϵ = _reinterpret_unsafe(P, d.storage_ϵ)
187+
partials_storage_ϵ = _reinterpret_unsafe(P, ex.partials_storage_ϵ)
193188
x_values_ϵ = _reinterpret_unsafe(P, d.input_ϵ)
194189
subexpression_values_ϵ =
195190
_reinterpret_unsafe(P, d.subexpression_forward_values_ϵ)
@@ -370,14 +365,17 @@ end
370365
# to compute hessian-vector products.
371366
function _reverse_eval_ϵ(
372367
output_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
373-
ex::Union{_FunctionStorage,_SubexpressionStorage},
368+
ex::_SubexpressionStorage,
374369
reverse_storage_ϵ,
375-
partials_storage_ϵ,
376370
subexpression_output,
377371
subexpression_output_ϵ,
378372
scale::T,
379373
scale_ϵ::ForwardDiff.Partials{N,T},
380374
) where {N,T}
375+
partials_storage_ϵ = _reinterpret_unsafe(
376+
ForwardDiff.Partials{N,T},
377+
ex.partials_storage_ϵ,
378+
)
381379
@assert length(reverse_storage_ϵ) >= length(ex.nodes)
382380
@assert length(partials_storage_ϵ) >= length(ex.nodes)
383381
if ex.nodes[1].type == Nonlinear.NODE_VARIABLE

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
138138
push!(
139139
d.constraints,
140140
_FunctionStorage(
141-
_SubexpressionStorage(
142-
expr,
143-
adj,
144-
moi_index_to_consecutive_index,
145-
shared_partials_storage_ϵ,
146-
linearity[1],
147-
),
141+
subexpr,
148142
N,
149143
coloring_storage,
150144
d.want_hess,
@@ -364,11 +358,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
364358
subexpr_forward_values_ϵ = reinterpret(T, d.subexpression_forward_values_ϵ)
365359
for i in d.subexpression_order
366360
subexpr = d.subexpressions[i]
367-
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
368-
d,
369-
subexpr,
370-
reinterpret(T, subexpr.partials_storage_ϵ),
371-
)
361+
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(d, subexpr, T)
372362
end
373363
# we only need to do one reverse pass through the subexpressions as well
374364
subexpr_reverse_values_ϵ = reinterpret(T, d.subexpression_reverse_values_ϵ)
@@ -377,29 +367,23 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
377367
fill!(d.storage_ϵ, 0.0)
378368
fill!(output_ϵ, zero(T))
379369
if d.objective !== nothing
380-
_forward_eval_ϵ(
381-
d,
382-
something(d.objective),
383-
reinterpret(T, d.partials_storage_ϵ),
384-
)
370+
_forward_eval_ϵ(d, something(d.objective).expr, T)
385371
_reverse_eval_ϵ(
386372
output_ϵ,
387-
something(d.objective),
388-
reinterpret(T, d.storage_ϵ),
389-
reinterpret(T, d.partials_storage_ϵ),
373+
something(d.objective).expr,
374+
_reinterpret_unsafe(T, d.storage_ϵ),
390375
d.subexpression_reverse_values,
391376
subexpr_reverse_values_ϵ,
392377
σ,
393378
zero(T),
394379
)
395380
end
396381
for (i, con) in enumerate(d.constraints)
397-
_forward_eval_ϵ(d, con, reinterpret(T, d.partials_storage_ϵ))
382+
_forward_eval_ϵ(d, con.expr, T)
398383
_reverse_eval_ϵ(
399384
output_ϵ,
400-
con,
385+
con.expr,
401386
reinterpret(T, d.storage_ϵ),
402-
reinterpret(T, d.partials_storage_ϵ),
403387
d.subexpression_reverse_values,
404388
subexpr_reverse_values_ϵ,
405389
μ[i],
@@ -413,7 +397,6 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
413397
output_ϵ,
414398
subexpr,
415399
reinterpret(T, d.storage_ϵ),
416-
reinterpret(T, subexpr.partials_storage_ϵ),
417400
d.subexpression_reverse_values,
418401
subexpr_reverse_values_ϵ,
419402
d.subexpression_reverse_values[j],

src/Nonlinear/ReverseAD/types.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,11 @@ struct _FunctionStorage
101101
subexpression_edgelist,
102102
subexpression_variables,
103103
)
104-
@show edgelist
105104
hess_I, hess_J, rinfo = Coloring.hessian_color_preprocess(
106105
edgelist,
107106
num_variables,
108107
coloring_storage,
109108
)
110-
@show hess_I, hess_J
111109
seed_matrix = Coloring.seed_matrix(rinfo)
112110
return new(
113111
expr,
@@ -172,7 +170,6 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
172170
# so the length should be multiplied by the maximum number of epsilon components
173171
disable_2ndorder::Bool # don't offer Hess or HessVec
174172
want_hess::Bool
175-
partials_storage_ϵ::Vector{Float64} # (longest expression excluding subexpressions)
176173
storage_ϵ::Vector{Float64} # (longest expression including subexpressions)
177174
input_ϵ::Vector{Float64} # (number of variables)
178175
output_ϵ::Vector{Float64} # (number of variables)

0 commit comments

Comments
 (0)