Skip to content

Commit 35c7f6a

Browse files
Refactor InplaceObjective wrapping
1 parent 6ff316b commit 35c7f6a

File tree

1 file changed

+48
-24
lines changed

1 file changed

+48
-24
lines changed

ext/DynamicExpressionsOptimExt.jl

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,48 @@ function wrap_func(
9090
return nothing
9191
end
9292

93+
const _INPLACEOBJECTIVE_SPEC_V8 = (
94+
fields = (:fdf, :fgh, :hvp, :fghvp, :fjvp),
95+
x_last = (:fdf, :fgh),
96+
xv_tail = (:hvp, :fghvp, :fjvp),
97+
)
98+
const _INPLACEOBJECTIVE_SPEC_V7 = (
99+
fields = (:df, :fdf, :fgh, :hv, :fghv),
100+
x_last = (:df, :fdf, :fgh),
101+
xv_tail = (:hv, :fghv),
102+
)
103+
const _INPLACEOBJECTIVE_SPEC_OLD = (
104+
fields = (:fdf, :fgh, :hv, :fghv),
105+
x_last = (:fdf, :fgh),
106+
xv_tail = (:hv, :fghv),
107+
)
108+
109+
@inline function _wrap_inplaceobjective_field(
110+
::Val{field}, f::NLSolversBase.InplaceObjective, tree::N, refs, spec
111+
) where {field,N<:Union{AbstractExpressionNode,AbstractExpression}}
112+
if field in spec.x_last
113+
return _wrap_objective_x_last(getfield(f, field), tree, refs)
114+
elseif field in spec.xv_tail
115+
return _wrap_objective_xv_tail(getfield(f, field), tree, refs)
116+
else
117+
throw(
118+
ArgumentError(
119+
"Internal error: no wrapping rule for InplaceObjective field $(field). " *
120+
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
121+
),
122+
)
123+
end
124+
end
125+
126+
@inline function _wrap_inplaceobjective(
127+
f::NLSolversBase.InplaceObjective, tree::N, refs, spec
128+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
129+
wrapped = map(spec.fields) do field
130+
_wrap_inplaceobjective_field(Val(field), f, tree, refs, spec)
131+
end
132+
return NLSolversBase.InplaceObjective(wrapped...)
133+
end
134+
93135
function wrap_func(
94136
f::NLSolversBase.InplaceObjective, tree::N, refs
95137
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
@@ -99,33 +141,15 @@ function wrap_func(
99141
#
100142
# We use `@static` branching so that only the relevant layout for the *installed*
101143
# NLSolversBase version is compiled/instrumented.
102-
@static if fieldnames(NLSolversBase.InplaceObjective) ==
103-
(:fdf, :fgh, :hvp, :fghvp, :fjvp)
144+
@static if fieldnames(NLSolversBase.InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V8.fields
104145
# NLSolversBase v8 / Optim v2
105-
return NLSolversBase.InplaceObjective(
106-
_wrap_objective_x_last(getfield(f, :fdf), tree, refs),
107-
_wrap_objective_x_last(getfield(f, :fgh), tree, refs),
108-
_wrap_objective_xv_tail(getfield(f, :hvp), tree, refs),
109-
_wrap_objective_xv_tail(getfield(f, :fghvp), tree, refs),
110-
_wrap_objective_xv_tail(getfield(f, :fjvp), tree, refs),
111-
)
112-
elseif fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
146+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
147+
elseif fieldnames(NLSolversBase.InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V7.fields
113148
# NLSolversBase v7 / Optim v1
114-
return NLSolversBase.InplaceObjective(
115-
_wrap_objective_x_last(getfield(f, :df), tree, refs),
116-
_wrap_objective_x_last(getfield(f, :fdf), tree, refs),
117-
_wrap_objective_x_last(getfield(f, :fgh), tree, refs),
118-
_wrap_objective_xv_tail(getfield(f, :hv), tree, refs),
119-
_wrap_objective_xv_tail(getfield(f, :fghv), tree, refs),
120-
)
121-
elseif fieldnames(NLSolversBase.InplaceObjective) == (:fdf, :fgh, :hv, :fghv)
149+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
150+
elseif fieldnames(NLSolversBase.InplaceObjective) == _INPLACEOBJECTIVE_SPEC_OLD.fields
122151
# Older NLSolversBase / Optim
123-
return NLSolversBase.InplaceObjective(
124-
_wrap_objective_x_last(getfield(f, :fdf), tree, refs),
125-
_wrap_objective_x_last(getfield(f, :fgh), tree, refs),
126-
_wrap_objective_xv_tail(getfield(f, :hv), tree, refs),
127-
_wrap_objective_xv_tail(getfield(f, :fghv), tree, refs),
128-
)
152+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_OLD)
129153
else
130154
fields = fieldnames(NLSolversBase.InplaceObjective)
131155
throw(

0 commit comments

Comments
 (0)