Skip to content

Commit f9bc4a6

Browse files
SebastianM-CMilesCranmerBotMilesCranmerclaude
authored
deps: add Optim v2 support (#159)
* Compat: support Optim v2 * fix: wrap InplaceObjective hvp/fghvp/fjvp argument order * fix: explicit InplaceObjective wrapping * fix: support Optim v1 InplaceObjective * test: cover objective arity error; @static InplaceObjective layout * style: run JuliaFormatter * Refactor InplaceObjective wrapping * test: cover Optim extension error path * fix: drop Optim <1; add NLSolversBase compat + Optim v1 CI * style: run JuliaFormatter * fix: support Optim v2 (NLSolversBase v8) + add CI smoketest * style: format Optim extension * ci: simplify Optim CI and avoid Optim-internal only_fg! * fix(tests): filter testitems by SR_TEST in runtests * style: format test runtests * fix: initialize variable-node val to avoid NaNs * Revert "fix: initialize variable-node val to avoid NaNs" This reverts commit 10b59c2. * fix(tests): include test_optim.jl in "main" test group The runtests.jl refactor that added SR_TEST filtering accidentally excluded test_optim.jl from the "main" group. This meant the regular CI (which resolves Optim v2) ran zero Optim extension tests — only the dedicated v1 smoketest exercised the extension. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: MilesCranmerBot <miles.cranmer.bot@gmail.com> Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 717146e commit f9bc4a6

File tree

6 files changed

+321
-31
lines changed

6 files changed

+321
-31
lines changed

.github/workflows/CI.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,43 @@ jobs:
9999
path: lcov.info
100100

101101

102+
optim_v1_smoketest:
103+
name: Optim v1 (NLSolversBase v7) - ubuntu-latest
104+
runs-on: ubuntu-latest
105+
timeout-minutes: 60
106+
steps:
107+
- uses: actions/checkout@v4
108+
- uses: julia-actions/setup-julia@v2
109+
with:
110+
version: '1'
111+
- uses: julia-actions/cache@v2
112+
- uses: julia-actions/julia-buildpkg@v1
113+
- name: Pin Optim v1 + NLSolversBase v7
114+
run: |
115+
julia --color=yes -e 'import Pkg; Pkg.add("Coverage")'
116+
julia --color=yes -e 'import Pkg; Pkg.activate("."); Pkg.add(Pkg.PackageSpec(name="Optim", version="1")); Pkg.add(Pkg.PackageSpec(name="NLSolversBase", version="7")); Pkg.status(["Optim", "NLSolversBase"])'
117+
shell: bash
118+
- name: Run Optim tests (with coverage)
119+
id: run-tests
120+
run: |
121+
SR_TEST=optim julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
122+
julia --color=yes coverage.jl
123+
shell: bash
124+
- name: Upload coverage artifact
125+
if: steps.run-tests.outcome == 'success'
126+
uses: actions/upload-artifact@v6
127+
with:
128+
name: coverage-optim-v1-${{ runner.os }}-julia-1
129+
path: lcov.info
130+
131+
102132
codecov:
103133
name: Upload combined coverage to Codecov
104134
runs-on: ubuntu-latest
105135
needs:
106136
- test
107137
- additional_tests
138+
- optim_v1_smoketest
108139
steps:
109140
# Codecov uploader expects a git checkout (commit metadata + repo root)
110141
- uses: actions/checkout@v4

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1818
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
1919
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
2020
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
21+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
2122
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2223
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2324

2425
[extensions]
2526
DynamicExpressionsBumperExt = "Bumper"
2627
DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
27-
DynamicExpressionsOptimExt = "Optim"
28+
DynamicExpressionsOptimExt = ["Optim", "NLSolversBase"]
2829
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
2930
DynamicExpressionsZygoteExt = "Zygote"
3031

@@ -36,7 +37,8 @@ DispatchDoctor = "0.4"
3637
Interfaces = "0.3"
3738
LoopVectorization = "0.12"
3839
MacroTools = "0.4, 0.5"
39-
Optim = "0.19, 1"
40+
Optim = "1, 2"
41+
NLSolversBase = "7, 8"
4042
PrecompileTools = "1"
4143
Reexport = "1"
4244
SymbolicUtils = "4"
@@ -49,5 +51,6 @@ TOML = "1"
4951
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
5052
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
5153
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
54+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
5255
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
5356
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

ext/DynamicExpressionsOptimExt.jl

Lines changed: 115 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using DynamicExpressions:
99
set_scalar_constants!,
1010
get_number_type
1111

12-
import Optim: Optim, OptimizationResults, NLSolversBase
12+
import Optim: Optim, OptimizationResults
13+
using NLSolversBase: NLSolversBase
1314

1415
#! format: off
1516
"""
@@ -38,41 +39,136 @@ function Optim.minimizer(r::ExpressionOptimizationResults)
3839
end
3940

4041
"""Wrap function or objective with insertion of values of the constant nodes."""
41-
function wrap_func(
42+
@inline function _wrap_objective_x_last(
43+
::Nothing, tree::N, refs
44+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
45+
return nothing
46+
end
47+
@inline function _wrap_objective_x_last(
4248
f::F, tree::N, refs
4349
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
4450
function wrapped_f(args::Vararg{Any,M}) where {M}
45-
first_args = args[begin:(end - 1)]
46-
x = args[end]
51+
x = args[M]
4752
set_scalar_constants!(tree, x, refs)
48-
return @inline(f(first_args..., tree))
53+
newargs = Base.setindex(args, tree, M)
54+
return @inline(f(newargs...))
4955
end
50-
# without first args, it looks like this
51-
# function wrapped_f(x)
52-
# set_scalar_constants!(tree, x, refs)
53-
# return @inline(f(tree))
54-
# end
5556
return wrapped_f
5657
end
58+
59+
@inline function _wrap_objective_xv_tail(
60+
::Nothing, tree::N, refs
61+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
62+
return nothing
63+
end
64+
@inline function _wrap_objective_xv_tail(
65+
f::F, tree::N, refs
66+
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
67+
function wrapped_f(args::Vararg{Any,M}) where {M}
68+
if M < 2
69+
throw(
70+
ArgumentError(
71+
"Expected at least 2 arguments for objective functions of the form (..., x, v).",
72+
),
73+
)
74+
end
75+
x = args[M - 1]
76+
set_scalar_constants!(tree, x, refs)
77+
newargs = Base.setindex(args, tree, M - 1)
78+
return @inline(f(newargs...))
79+
end
80+
return wrapped_f
81+
end
82+
83+
function wrap_func(
84+
f::F, tree::N, refs
85+
) where {F<:Function,T,N<:Union{AbstractExpressionNode{T},AbstractExpression{T}}}
86+
return _wrap_objective_x_last(f, tree, refs)
87+
end
5788
function wrap_func(
5889
::Nothing, tree::N, refs
5990
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
6091
return nothing
6192
end
93+
94+
# `NLSolversBase.InplaceObjective` is an internal type whose field layout changed
95+
# between NLSolversBase versions (and therefore between Optim majors).
96+
#
97+
# This extension supports:
98+
# - Optim v1.x (NLSolversBase v7.x): df, fdf, fgh, hv, fghv
99+
# - Optim v2.x (NLSolversBase v8.x): fdf, fgh, hvp, fghvp, fjvp
100+
#
101+
# We store the fields both as symbols (for runtime layout checks) and as `Val`s
102+
# (so the wrapper construction is type-stable and can compile-in the field set).
103+
const _INPLACEOBJECTIVE_SPEC_V8 = (
104+
field_syms=(:fdf, :fgh, :hvp, :fghvp, :fjvp),
105+
fields=(Val(:fdf), Val(:fgh), Val(:hvp), Val(:fghvp), Val(:fjvp)),
106+
x_last=(Val(:fdf), Val(:fgh)),
107+
xv_tail=(Val(:hvp), Val(:fghvp), Val(:fjvp)),
108+
)
109+
const _INPLACEOBJECTIVE_SPEC_V7 = (
110+
field_syms=(:df, :fdf, :fgh, :hv, :fghv),
111+
fields=(Val(:df), Val(:fdf), Val(:fgh), Val(:hv), Val(:fghv)),
112+
x_last=(Val(:df), Val(:fdf), Val(:fgh)),
113+
xv_tail=(Val(:hv), Val(:fghv)),
114+
)
115+
116+
@inline function _wrap_inplaceobjective_field(
117+
v_field::Val{field}, f::NLSolversBase.InplaceObjective, tree::N, refs, spec
118+
) where {field,N<:Union{AbstractExpressionNode,AbstractExpression}}
119+
if v_field in spec.x_last
120+
return _wrap_objective_x_last(getfield(f, field), tree, refs)
121+
elseif v_field in spec.xv_tail
122+
return _wrap_objective_xv_tail(getfield(f, field), tree, refs)
123+
else
124+
throw(
125+
ArgumentError(
126+
"Internal error: no wrapping rule for InplaceObjective field $(field). " *
127+
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
128+
),
129+
)
130+
end
131+
end
132+
133+
@inline function _wrap_inplaceobjective(
134+
f::NLSolversBase.InplaceObjective, tree::N, refs, spec
135+
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
136+
wrapped = map(spec.fields) do v_field
137+
_wrap_inplaceobjective_field(v_field, f, tree, refs, spec)
138+
end
139+
return NLSolversBase.InplaceObjective(wrapped...)
140+
end
141+
62142
function wrap_func(
63143
f::NLSolversBase.InplaceObjective, tree::N, refs
64144
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
65-
# Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
145+
# Some objectives, like `only_fg!(fg!)`, are not functions but instead
66146
# `InplaceObjective`. These contain multiple functions, each of which needs to be
67147
# wrapped. Some functions are `nothing`; those can be left as-is.
68-
@assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
69-
return NLSolversBase.InplaceObjective(
70-
wrap_func(f.df, tree, refs),
71-
wrap_func(f.fdf, tree, refs),
72-
wrap_func(f.fgh, tree, refs),
73-
wrap_func(f.hv, tree, refs),
74-
wrap_func(f.fghv, tree, refs),
75-
)
148+
#
149+
# We use `@static` branching so that only the relevant layout for the *installed*
150+
# NLSolversBase version is compiled/instrumented.
151+
@static if fieldnames(NLSolversBase.InplaceObjective) ==
152+
_INPLACEOBJECTIVE_SPEC_V8.field_syms
153+
# NLSolversBase v8 / Optim v2
154+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
155+
elseif fieldnames(NLSolversBase.InplaceObjective) ==
156+
_INPLACEOBJECTIVE_SPEC_V7.field_syms
157+
# NLSolversBase v7 / Optim v1
158+
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
159+
# (Optim < 1 is no longer supported.)
160+
else
161+
# LCOV_EXCL_START
162+
fields = fieldnames(NLSolversBase.InplaceObjective)
163+
throw(
164+
ArgumentError(
165+
"Unsupported NLSolversBase.InplaceObjective field layout: $(fields). " *
166+
"This extension supports layouts used by NLSolversBase v7 (Optim v1) and v8 (Optim v2). " *
167+
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
168+
),
169+
)
170+
# LCOV_EXCL_END
171+
end
76172
end
77173

78174
"""

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1313
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
14+
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
1415
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/runtests.jl

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
11
using SafeTestsets
22
using TestItemRunner
33

4-
# Check if SR_ENZYME_TEST is set in env
5-
test_name = split(get(ENV, "SR_TEST", "main"), ",")
4+
# Control which test groups run.
5+
#
6+
# Accepts a comma-separated list in SR_TEST (default: "main").
7+
#
8+
# - "main": full test suite (testitems)
9+
# - "optim": Optim-specific testitems only
10+
# - "narity": n-ary testitems only
11+
# - "jet": JET analysis
12+
# - "enzyme": Enzyme tests
613

7-
unknown_tests = filter(Base.Fix2(, ["enzyme", "jet", "main", "narity"]), test_name)
14+
test_names = split(get(ENV, "SR_TEST", "main"), ",")
15+
16+
allowed = ["enzyme", "jet", "main", "narity", "optim"]
17+
unknown_tests = filter(Base.Fix2(, allowed), test_names)
818

919
if !isempty(unknown_tests)
1020
error("Unknown test names: $unknown_tests")
1121
end
1222

13-
if "enzyme" in test_name
23+
if "enzyme" in test_names
1424
@safetestset "Test enzyme derivatives" begin
1525
include("test_enzyme.jl")
1626
end
1727
end
18-
if "jet" in test_name
28+
29+
if "jet" in test_names
1930
@safetestset "JET" begin
2031
using Preferences
2132
set_preferences!(
@@ -61,7 +72,25 @@ if "jet" in test_name
6172
end
6273
end
6374
end
64-
if "main" in test_name
65-
include("unittest.jl")
66-
@run_package_tests
75+
76+
# TestItemRunner's `@run_package_tests` scans *all* `.jl` files under the package root,
77+
# so we must filter to only the testitem files we actually want to run.
78+
# (Simply `include(...)`-ing a subset of files is not sufficient.)
79+
80+
testitem_suffixes = String[]
81+
82+
if "main" in test_names
83+
push!(testitem_suffixes, joinpath("test", "unittest.jl"))
84+
push!(testitem_suffixes, joinpath("test", "test_optim.jl"))
85+
end
86+
if "optim" in test_names
87+
push!(testitem_suffixes, joinpath("test", "test_optim.jl"))
88+
end
89+
if "narity" in test_names
90+
push!(testitem_suffixes, joinpath("test", "test_n_arity_nodes.jl"))
91+
end
92+
93+
if !isempty(testitem_suffixes)
94+
@run_package_tests filter =
95+
ti -> any(suf -> endswith(ti.filename, suf), testitem_suffixes)
6796
end

0 commit comments

Comments
 (0)