@@ -9,7 +9,8 @@ using DynamicExpressions:
99 set_scalar_constants!,
1010 get_number_type
1111
12- import Optim: Optim, OptimizationResults, NLSolversBase
12+ import Optim: Optim, OptimizationResults
13+ using NLSolversBase: NLSolversBase
1314
1415# ! format: off
1516"""
@@ -38,41 +39,136 @@ function Optim.minimizer(r::ExpressionOptimizationResults)
3839end
3940
4041""" Wrap function or objective with insertion of values of the constant nodes."""
41- function wrap_func (
42+ @inline function _wrap_objective_x_last (
43+ :: Nothing , tree:: N , refs
44+ ) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
45+ return nothing
46+ end
47+ @inline function _wrap_objective_x_last (
4248 f:: F , tree:: N , refs
4349) where {F<: Function ,T,N<: Union{AbstractExpressionNode{T},AbstractExpression{T}} }
4450 function wrapped_f (args:: Vararg{Any,M} ) where {M}
45- first_args = args[begin : (end - 1 )]
46- x = args[end ]
51+ x = args[M]
4752 set_scalar_constants! (tree, x, refs)
48- return @inline (f (first_args... , tree))
53+ newargs = Base. setindex (args, tree, M)
54+ return @inline (f (newargs... ))
4955 end
50- # without first args, it looks like this
51- # function wrapped_f(x)
52- # set_scalar_constants!(tree, x, refs)
53- # return @inline(f(tree))
54- # end
5556 return wrapped_f
5657end
58+
59+ @inline function _wrap_objective_xv_tail (
60+ :: Nothing , tree:: N , refs
61+ ) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
62+ return nothing
63+ end
64+ @inline function _wrap_objective_xv_tail (
65+ f:: F , tree:: N , refs
66+ ) where {F<: Function ,T,N<: Union{AbstractExpressionNode{T},AbstractExpression{T}} }
67+ function wrapped_f (args:: Vararg{Any,M} ) where {M}
68+ if M < 2
69+ throw (
70+ ArgumentError (
71+ " Expected at least 2 arguments for objective functions of the form (..., x, v)." ,
72+ ),
73+ )
74+ end
75+ x = args[M - 1 ]
76+ set_scalar_constants! (tree, x, refs)
77+ newargs = Base. setindex (args, tree, M - 1 )
78+ return @inline (f (newargs... ))
79+ end
80+ return wrapped_f
81+ end
82+
83+ function wrap_func (
84+ f:: F , tree:: N , refs
85+ ) where {F<: Function ,T,N<: Union{AbstractExpressionNode{T},AbstractExpression{T}} }
86+ return _wrap_objective_x_last (f, tree, refs)
87+ end
5788function wrap_func (
5889 :: Nothing , tree:: N , refs
5990) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
6091 return nothing
6192end
93+
94+ # `NLSolversBase.InplaceObjective` is an internal type whose field layout changed
95+ # between NLSolversBase versions (and therefore between Optim majors).
96+ #
97+ # This extension supports:
98+ # - Optim v1.x (NLSolversBase v7.x): df, fdf, fgh, hv, fghv
99+ # - Optim v2.x (NLSolversBase v8.x): fdf, fgh, hvp, fghvp, fjvp
100+ #
101+ # We store the fields both as symbols (for runtime layout checks) and as `Val`s
102+ # (so the wrapper construction is type-stable and can compile-in the field set).
103+ const _INPLACEOBJECTIVE_SPEC_V8 = (
104+ field_syms= (:fdf , :fgh , :hvp , :fghvp , :fjvp ),
105+ fields= (Val (:fdf ), Val (:fgh ), Val (:hvp ), Val (:fghvp ), Val (:fjvp )),
106+ x_last= (Val (:fdf ), Val (:fgh )),
107+ xv_tail= (Val (:hvp ), Val (:fghvp ), Val (:fjvp )),
108+ )
109+ const _INPLACEOBJECTIVE_SPEC_V7 = (
110+ field_syms= (:df , :fdf , :fgh , :hv , :fghv ),
111+ fields= (Val (:df ), Val (:fdf ), Val (:fgh ), Val (:hv ), Val (:fghv )),
112+ x_last= (Val (:df ), Val (:fdf ), Val (:fgh )),
113+ xv_tail= (Val (:hv ), Val (:fghv )),
114+ )
115+
116+ @inline function _wrap_inplaceobjective_field (
117+ v_field:: Val{field} , f:: NLSolversBase.InplaceObjective , tree:: N , refs, spec
118+ ) where {field,N<: Union{AbstractExpressionNode,AbstractExpression} }
119+ if v_field in spec. x_last
120+ return _wrap_objective_x_last (getfield (f, field), tree, refs)
121+ elseif v_field in spec. xv_tail
122+ return _wrap_objective_xv_tail (getfield (f, field), tree, refs)
123+ else
124+ throw (
125+ ArgumentError (
126+ " Internal error: no wrapping rule for InplaceObjective field $(field) . " *
127+ " Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions." ,
128+ ),
129+ )
130+ end
131+ end
132+
133+ @inline function _wrap_inplaceobjective (
134+ f:: NLSolversBase.InplaceObjective , tree:: N , refs, spec
135+ ) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
136+ wrapped = map (spec. fields) do v_field
137+ _wrap_inplaceobjective_field (v_field, f, tree, refs, spec)
138+ end
139+ return NLSolversBase. InplaceObjective (wrapped... )
140+ end
141+
62142function wrap_func (
63143 f:: NLSolversBase.InplaceObjective , tree:: N , refs
64144) where {N<: Union{AbstractExpressionNode,AbstractExpression} }
65- # Some objectives, like `Optim. only_fg!(fg!)`, are not functions but instead
145+ # Some objectives, like `only_fg!(fg!)`, are not functions but instead
66146 # `InplaceObjective`. These contain multiple functions, each of which needs to be
67147 # wrapped. Some functions are `nothing`; those can be left as-is.
68- @assert fieldnames (NLSolversBase. InplaceObjective) == (:df , :fdf , :fgh , :hv , :fghv )
69- return NLSolversBase. InplaceObjective (
70- wrap_func (f. df, tree, refs),
71- wrap_func (f. fdf, tree, refs),
72- wrap_func (f. fgh, tree, refs),
73- wrap_func (f. hv, tree, refs),
74- wrap_func (f. fghv, tree, refs),
75- )
148+ #
149+ # We use `@static` branching so that only the relevant layout for the *installed*
150+ # NLSolversBase version is compiled/instrumented.
151+ @static if fieldnames (NLSolversBase. InplaceObjective) ==
152+ _INPLACEOBJECTIVE_SPEC_V8. field_syms
153+ # NLSolversBase v8 / Optim v2
154+ return _wrap_inplaceobjective (f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
155+ elseif fieldnames (NLSolversBase. InplaceObjective) ==
156+ _INPLACEOBJECTIVE_SPEC_V7. field_syms
157+ # NLSolversBase v7 / Optim v1
158+ return _wrap_inplaceobjective (f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
159+ # (Optim < 1 is no longer supported.)
160+ else
161+ # LCOV_EXCL_START
162+ fields = fieldnames (NLSolversBase. InplaceObjective)
163+ throw (
164+ ArgumentError (
165+ " Unsupported NLSolversBase.InplaceObjective field layout: $(fields) . " *
166+ " This extension supports layouts used by NLSolversBase v7 (Optim v1) and v8 (Optim v2). " *
167+ " Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions." ,
168+ ),
169+ )
170+ # LCOV_EXCL_END
171+ end
76172end
77173
78174"""
0 commit comments