@@ -90,6 +90,48 @@ function wrap_func(
9090 return nothing
9191end
9292
93+ const _INPLACEOBJECTIVE_SPEC_V8 = (
94+ fields = (:fdf , :fgh , :hvp , :fghvp , :fjvp ),
95+ x_last = (:fdf , :fgh ),
96+ xv_tail = (:hvp , :fghvp , :fjvp ),
97+ )
98+ const _INPLACEOBJECTIVE_SPEC_V7 = (
99+ fields = (:df , :fdf , :fgh , :hv , :fghv ),
100+ x_last = (:df , :fdf , :fgh ),
101+ xv_tail = (:hv , :fghv ),
102+ )
103+ const _INPLACEOBJECTIVE_SPEC_OLD = (
104+ fields = (:fdf , :fgh , :hv , :fghv ),
105+ x_last = (:fdf , :fgh ),
106+ xv_tail = (:hv , :fghv ),
107+ )
108+
109+ @inline function _wrap_inplaceobjective_field (
110+ :: Val{field} , f:: NLSolversBase.InplaceObjective , tree:: N , refs, spec
111+ ) where {field,N<: Union{AbstractExpressionNode,AbstractExpression} }
112+ if field in spec. x_last
113+ return _wrap_objective_x_last (getfield (f, field), tree, refs)
114+ elseif field in spec. xv_tail
115+ return _wrap_objective_xv_tail (getfield (f, field), tree, refs)
116+ else
117+ throw (
118+ ArgumentError (
119+ " Internal error: no wrapping rule for InplaceObjective field $(field) . " *
120+ " Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions." ,
121+ ),
122+ )
123+ end
124+ end
125+
126+ @inline function _wrap_inplaceobjective (
127+ f:: NLSolversBase.InplaceObjective , tree:: N , refs, spec
128+ ) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
129+ wrapped = map (spec. fields) do field
130+ _wrap_inplaceobjective_field (Val (field), f, tree, refs, spec)
131+ end
132+ return NLSolversBase. InplaceObjective (wrapped... )
133+ end
134+
93135function wrap_func (
94136 f:: NLSolversBase.InplaceObjective , tree:: N , refs
95137) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
@@ -99,33 +141,15 @@ function wrap_func(
99141 #
100142 # We use `@static` branching so that only the relevant layout for the *installed*
101143 # NLSolversBase version is compiled/instrumented.
102- @static if fieldnames (NLSolversBase. InplaceObjective) ==
103- (:fdf , :fgh , :hvp , :fghvp , :fjvp )
144+ @static if fieldnames (NLSolversBase. InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V8. fields
104145 # NLSolversBase v8 / Optim v2
105- return NLSolversBase. InplaceObjective (
106- _wrap_objective_x_last (getfield (f, :fdf ), tree, refs),
107- _wrap_objective_x_last (getfield (f, :fgh ), tree, refs),
108- _wrap_objective_xv_tail (getfield (f, :hvp ), tree, refs),
109- _wrap_objective_xv_tail (getfield (f, :fghvp ), tree, refs),
110- _wrap_objective_xv_tail (getfield (f, :fjvp ), tree, refs),
111- )
112- elseif fieldnames (NLSolversBase. InplaceObjective) == (:df , :fdf , :fgh , :hv , :fghv )
146+ return _wrap_inplaceobjective (f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
147+ elseif fieldnames (NLSolversBase. InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V7. fields
113148 # NLSolversBase v7 / Optim v1
114- return NLSolversBase. InplaceObjective (
115- _wrap_objective_x_last (getfield (f, :df ), tree, refs),
116- _wrap_objective_x_last (getfield (f, :fdf ), tree, refs),
117- _wrap_objective_x_last (getfield (f, :fgh ), tree, refs),
118- _wrap_objective_xv_tail (getfield (f, :hv ), tree, refs),
119- _wrap_objective_xv_tail (getfield (f, :fghv ), tree, refs),
120- )
121- elseif fieldnames (NLSolversBase. InplaceObjective) == (:fdf , :fgh , :hv , :fghv )
149+ return _wrap_inplaceobjective (f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
150+ elseif fieldnames (NLSolversBase. InplaceObjective) == _INPLACEOBJECTIVE_SPEC_OLD. fields
122151 # Older NLSolversBase / Optim
123- return NLSolversBase. InplaceObjective (
124- _wrap_objective_x_last (getfield (f, :fdf ), tree, refs),
125- _wrap_objective_x_last (getfield (f, :fgh ), tree, refs),
126- _wrap_objective_xv_tail (getfield (f, :hv ), tree, refs),
127- _wrap_objective_xv_tail (getfield (f, :fghv ), tree, refs),
128- )
152+ return _wrap_inplaceobjective (f, tree, refs, _INPLACEOBJECTIVE_SPEC_OLD)
129153 else
130154 fields = fieldnames (NLSolversBase. InplaceObjective)
131155 throw (
0 commit comments