Skip to content

Commit e0f6ac1

Browse files
committed
Fixes
1 parent dcdca64 commit e0f6ac1

File tree

3 files changed

+24
-17
lines changed

3 files changed

+24
-17
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function _eval_hessian(
3838
scale::Float64,
3939
nzcount::Int,
4040
)::Int
41-
if ex.linearity == LINEAR
41+
if ex.expr.linearity == LINEAR
4242
@assert length(ex.hess_I) == 0
4343
return 0
4444
end

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
168168
linearity,
169169
),
170170
)
171-
max_expr_length = max(max_expr_length, length(d.constraints[end].nodes))
171+
max_expr_length = max(max_expr_length, length(expr.nodes))
172172
max_chunk = max(max_chunk, size(d.constraints[end].seed_matrix, 2))
173173
end
174174
max_chunk = min(max_chunk, MAX_CHUNK)
@@ -210,7 +210,7 @@ function MOI.eval_objective(d::NLPEvaluator, x)
210210
error("No nonlinear objective.")
211211
end
212212
_reverse_mode(d, x)
213-
return something(d.objective).forward_storage[1]
213+
return something(d.objective).expr.forward_storage[1]
214214
end
215215

216216
function MOI.eval_objective_gradient(d::NLPEvaluator, g, x)
@@ -226,7 +226,7 @@ end
226226
function MOI.eval_constraint(d::NLPEvaluator, g, x)
227227
_reverse_mode(d, x)
228228
for i in 1:length(d.constraints)
229-
g[i] = d.constraints[i].forward_storage[1]
229+
g[i] = d.constraints[i].expr.forward_storage[1]
230230
end
231231
return
232232
end

src/Nonlinear/ReverseAD/reverse_mode.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,7 @@ Forward-mode evaluation of an expression tree given in `f`.
9898
associate storage with each edge of the DAG.
9999
"""
100100
function _forward_eval(
101-
# !!! warning
102-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
103-
# having similarly named fields.
104-
f::Union{_FunctionStorage,_SubexpressionStorage},
101+
f::_SubexpressionStorage,
105102
d::NLPEvaluator,
106103
x::AbstractVector{T},
107104
)::T where {T}
@@ -289,6 +286,8 @@ function _forward_eval(
289286
return f.forward_storage[1]
290287
end
291288

289+
_forward_eval(f::_FunctionStorage, d, x) = _forward_eval(f.expr, d, x)
290+
292291
"""
293292
_reverse_eval(f::Union{_FunctionStorage,_SubexpressionStorage})
294293
@@ -297,12 +296,7 @@ Reverse-mode evaluation of an expression tree given in `f`.
297296
* This function assumes `f.partials_storage` is already updated.
298297
* This function assumes that `f.reverse_storage` has been initialized with 0.0.
299298
"""
300-
function _reverse_eval(
301-
# !!! warning
302-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
303-
# having similarly named fields.
304-
f::Union{_FunctionStorage,_SubexpressionStorage},
305-
)
299+
function _reverse_eval(f::_SubexpressionStorage)
306300
@assert length(f.reverse_storage) >= length(f.nodes)
307301
@assert length(f.partials_storage) >= length(f.nodes)
308302
# f.nodes is already in order such that parents always appear before
@@ -328,6 +322,8 @@ function _reverse_eval(
328322
return
329323
end
330324

325+
_reverse_eval(f::_FunctionStorage) = _reverse_eval(f.expr)
326+
331327
"""
332328
_extract_reverse_pass(
333329
g::AbstractVector{T},
@@ -361,9 +357,20 @@ end
361357

362358
function _extract_reverse_pass_inner(
363359
output::AbstractVector{T},
364-
# !!! warning
365-
# This Union depends upon _FunctionStorage and _SubexpressionStorage
366-
# having similarly named fields.
360+
f::_FunctionStorage,
361+
subexpressions::AbstractVector{T},
362+
scale::T,
363+
) where {T}
364+
return _extract_reverse_pass_inner(
365+
output,
366+
f.expr,
367+
subexpressions,
368+
scale,
369+
)
370+
end
371+
372+
function _extract_reverse_pass_inner(
373+
output::AbstractVector{T},
367374
f::Union{_FunctionStorage,_SubexpressionStorage},
368375
subexpressions::AbstractVector{T},
369376
scale::T,

0 commit comments

Comments
 (0)