Skip to content

Commit adceb0f

Browse files
committed
[Nonlinear] Merge forward_storage_ϵ with reverse_storage_ϵ
1 parent 355a039 commit adceb0f

File tree

3 files changed

+18
-23
lines changed

3 files changed

+18
-23
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
145145
_forward_eval_ϵ(
146146
d,
147147
ex,
148-
_reinterpret_unsafe(T, d.forward_storage_ϵ),
148+
_reinterpret_unsafe(T, d.storage_ϵ),
149149
_reinterpret_unsafe(T, d.partials_storage_ϵ),
150150
input_ϵ,
151151
subexpr_forward_values_ϵ,
@@ -161,7 +161,7 @@ function _hessian_slice_inner(d, ex, input_ϵ, output_ϵ, ::Type{T}) where {T}
161161
_reverse_eval_ϵ(
162162
output_ϵ,
163163
ex,
164-
_reinterpret_unsafe(T, d.reverse_storage_ϵ),
164+
_reinterpret_unsafe(T, d.storage_ϵ),
165165
_reinterpret_unsafe(T, d.partials_storage_ϵ),
166166
d.subexpression_reverse_values,
167167
subexpr_reverse_values_ϵ,

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -138,27 +138,25 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
138138
end
139139
# 10 is hardcoded upper bound to avoid excess memory allocation
140140
max_chunk = min(max_chunk, 10)
141+
max_expr_with_sub_length =
142+
maximum(d.subexpressions; init = max_expr_length) do subexpr
143+
return length(subexpr.nodes)
144+
end
141145
if d.want_hess || want_hess_storage
142146
d.input_ϵ = zeros(max_chunk * N)
143147
d.output_ϵ = zeros(max_chunk * N)
144148
#
145-
len = max_chunk * max_expr_length
146-
d.forward_storage_ϵ = zeros(len)
147-
d.partials_storage_ϵ = zeros(len)
148-
d.reverse_storage_ϵ = zeros(len)
149+
d.partials_storage_ϵ = zeros(max_chunk * max_expr_length)
150+
d.storage_ϵ = zeros(max_chunk * max_expr_with_sub_length)
149151
#
150152
len = max_chunk * length(d.subexpressions)
151153
d.subexpression_forward_values_ϵ = zeros(len)
152154
d.subexpression_reverse_values_ϵ = zeros(len)
153155
#
154156
for k in d.subexpression_order
155157
len = max_chunk * length(d.subexpressions[k].nodes)
156-
resize!(d.subexpressions[k].forward_storage_ϵ, len)
157-
fill!(d.subexpressions[k].forward_storage_ϵ, 0.0)
158158
resize!(d.subexpressions[k].partials_storage_ϵ, len)
159159
fill!(d.subexpressions[k].partials_storage_ϵ, 0.0)
160-
resize!(d.subexpressions[k].reverse_storage_ϵ, len)
161-
fill!(d.subexpressions[k].reverse_storage_ϵ, 0.0)
162160
end
163161
d.max_chunk = max_chunk
164162
if d.want_hess
@@ -350,7 +348,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
350348
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
351349
d,
352350
subexpr,
353-
reinterpret(T, subexpr.forward_storage_ϵ),
351+
reinterpret(T, subexpr.storage_ϵ),
354352
reinterpret(T, subexpr.partials_storage_ϵ),
355353
input_ϵ,
356354
subexpr_forward_values_ϵ,
@@ -361,13 +359,13 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
361359
subexpr_reverse_values_ϵ = reinterpret(T, d.subexpression_reverse_values_ϵ)
362360
fill!(subexpr_reverse_values_ϵ, zero(T))
363361
fill!(d.subexpression_reverse_values, 0.0)
364-
fill!(d.reverse_storage_ϵ, 0.0)
362+
fill!(d.storage_ϵ, 0.0)
365363
fill!(output_ϵ, zero(T))
366364
if d.objective !== nothing
367365
_forward_eval_ϵ(
368366
d,
369367
something(d.objective),
370-
reinterpret(T, d.forward_storage_ϵ),
368+
reinterpret(T, d.storage_ϵ),
371369
reinterpret(T, d.partials_storage_ϵ),
372370
input_ϵ,
373371
subexpr_forward_values_ϵ,
@@ -376,7 +374,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
376374
_reverse_eval_ϵ(
377375
output_ϵ,
378376
something(d.objective),
379-
reinterpret(T, d.reverse_storage_ϵ),
377+
reinterpret(T, d.storage_ϵ),
380378
reinterpret(T, d.partials_storage_ϵ),
381379
d.subexpression_reverse_values,
382380
subexpr_reverse_values_ϵ,
@@ -388,7 +386,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
388386
_forward_eval_ϵ(
389387
d,
390388
con,
391-
reinterpret(T, d.forward_storage_ϵ),
389+
reinterpret(T, d.storage_ϵ),
392390
reinterpret(T, d.partials_storage_ϵ),
393391
input_ϵ,
394392
subexpr_forward_values_ϵ,
@@ -397,7 +395,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
397395
_reverse_eval_ϵ(
398396
output_ϵ,
399397
con,
400-
reinterpret(T, d.reverse_storage_ϵ),
398+
reinterpret(T, d.storage_ϵ),
401399
reinterpret(T, d.partials_storage_ϵ),
402400
d.subexpression_reverse_values,
403401
subexpr_reverse_values_ϵ,
@@ -411,7 +409,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
411409
_reverse_eval_ϵ(
412410
output_ϵ,
413411
subexpr,
414-
reinterpret(T, subexpr.reverse_storage_ϵ),
412+
reinterpret(T, d.storage_ϵ),
415413
reinterpret(T, subexpr.partials_storage_ϵ),
416414
d.subexpression_reverse_values,
417415
subexpr_reverse_values_ϵ,

src/Nonlinear/ReverseAD/types.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ struct _SubexpressionStorage
1111
forward_storage::Vector{Float64}
1212
partials_storage::Vector{Float64}
1313
reverse_storage::Vector{Float64}
14-
forward_storage_ϵ::Vector{Float64}
1514
partials_storage_ϵ::Vector{Float64}
16-
reverse_storage_ϵ::Vector{Float64}
1715
linearity::Linearity
1816

1917
function _SubexpressionStorage(
@@ -175,11 +173,10 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
175173
# so the length should be multiplied by the maximum number of epsilon components
176174
disable_2ndorder::Bool # don't offer Hess or HessVec
177175
want_hess::Bool
178-
forward_storage_ϵ::Vector{Float64} # (longest expression)
179-
partials_storage_ϵ::Vector{Float64} # (longest expression)
180-
reverse_storage_ϵ::Vector{Float64} # (longest expression)
176+
partials_storage_ϵ::Vector{Float64} # (longest expression excluding subexpressions)
177+
storage_ϵ::Vector{Float64} # (longest expression including subexpressions)
181178
input_ϵ::Vector{Float64} # (number of variables)
182-
output_ϵ::Vector{Float64}# (number of variables)
179+
output_ϵ::Vector{Float64} # (number of variables)
183180
subexpression_forward_values_ϵ::Vector{Float64} # (number of subexpressions)
184181
subexpression_reverse_values_ϵ::Vector{Float64} # (number of subexpressions)
185182
hessian_sparsity::Vector{Tuple{Int64,Int64}}

0 commit comments

Comments
 (0)