Skip to content

Commit 1b0c9cd

Browse files
fix: support Optim v2 (NLSolversBase v8) + add CI smoketest
1 parent b03e599 commit 1b0c9cd

File tree

4 files changed

+112
-27
lines changed

4 files changed

+112
-27
lines changed

.github/workflows/CI.yml

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,53 @@ jobs:
110110
version: '1'
111111
- uses: julia-actions/cache@v2
112112
- uses: julia-actions/julia-buildpkg@v1
113-
- name: Pin Optim v1 and run tests
113+
- name: Pin Optim v1 + NLSolversBase v7
114114
run: |
115-
julia --color=yes -e 'import Pkg; Pkg.activate("test"); Pkg.add(Pkg.PackageSpec(name="Optim", version="1")); Pkg.status(["Optim", "NLSolversBase"])'
116-
julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes -e 'import Pkg; Pkg.activate("."); Pkg.test()'
115+
julia --color=yes -e 'import Pkg; Pkg.add("Coverage")'
116+
julia --color=yes -e 'import Pkg; Pkg.activate("."); Pkg.add(Pkg.PackageSpec(name="Optim", version="1")); Pkg.add(Pkg.PackageSpec(name="NLSolversBase", version="7")); Pkg.status(["Optim", "NLSolversBase"])'
117117
shell: bash
118+
- name: Run Optim tests (with coverage)
119+
id: run-tests
120+
run: |
121+
SR_TEST=optim julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
122+
julia --color=yes coverage.jl
123+
shell: bash
124+
- name: Upload coverage artifact
125+
if: steps.run-tests.outcome == 'success'
126+
uses: actions/upload-artifact@v6
127+
with:
128+
name: coverage-optim-v1-${{ runner.os }}-julia-1
129+
path: lcov.info
130+
131+
132+
optim_v2_smoketest:
133+
name: Optim v2 (NLSolversBase v8) - ubuntu-latest
134+
runs-on: ubuntu-latest
135+
timeout-minutes: 60
136+
steps:
137+
- uses: actions/checkout@v4
138+
- uses: julia-actions/setup-julia@v2
139+
with:
140+
version: '1'
141+
- uses: julia-actions/cache@v2
142+
- uses: julia-actions/julia-buildpkg@v1
143+
- name: Pin Optim v2 + NLSolversBase v8
144+
run: |
145+
julia --color=yes -e 'import Pkg; Pkg.add("Coverage")'
146+
julia --color=yes -e 'import Pkg; Pkg.activate("."); Pkg.add(Pkg.PackageSpec(name="Optim", version="2")); Pkg.add(Pkg.PackageSpec(name="NLSolversBase", version="8")); Pkg.status(["Optim", "NLSolversBase"])'
147+
shell: bash
148+
- name: Run Optim tests (with coverage)
149+
id: run-tests
150+
run: |
151+
SR_TEST=optim julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)'
152+
julia --color=yes coverage.jl
153+
shell: bash
154+
- name: Upload coverage artifact
155+
if: steps.run-tests.outcome == 'success'
156+
uses: actions/upload-artifact@v6
157+
with:
158+
name: coverage-optim-v2-${{ runner.os }}-julia-1
159+
path: lcov.info
118160

119161

120162
codecov:
@@ -123,6 +165,8 @@ jobs:
123165
needs:
124166
- test
125167
- additional_tests
168+
- optim_v1_smoketest
169+
- optim_v2_smoketest
126170
steps:
127171
# Codecov uploader expects a git checkout (commit metadata + repo root)
128172
- uses: actions/checkout@v4

ext/DynamicExpressionsOptimExt.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,25 +92,33 @@ function wrap_func(
9292
end
9393

9494
# `NLSolversBase.InplaceObjective` is an internal type whose field layout changed
95-
# between NLSolversBase versions.
95+
# between NLSolversBase versions (and therefore between Optim majors).
9696
#
97-
# - NLSolversBase v7 (Optim v1.x): df, fdf, fgh, hv, fghv
98-
# - NLSolversBase v8 (Optim v2.x): fdf, fgh, hvp, fghvp, fjvp
97+
# This extension supports:
98+
# - Optim v1.x (NLSolversBase v7.x): df, fdf, fgh, hv, fghv
99+
# - Optim v2.x (NLSolversBase v8.x): fdf, fgh, hvp, fghvp, fjvp
100+
#
101+
# We store the fields both as symbols (for runtime layout checks) and as `Val`s
102+
# (so the wrapper construction is type-stable and can compile-in the field set).
99103
const _INPLACEOBJECTIVE_SPEC_V8 = (
100-
fields=(:fdf, :fgh, :hvp, :fghvp, :fjvp),
101-
x_last=(:fdf, :fgh),
102-
xv_tail=(:hvp, :fghvp, :fjvp),
104+
field_syms = (:fdf, :fgh, :hvp, :fghvp, :fjvp),
105+
fields = (Val(:fdf), Val(:fgh), Val(:hvp), Val(:fghvp), Val(:fjvp)),
106+
x_last = (Val(:fdf), Val(:fgh)),
107+
xv_tail = (Val(:hvp), Val(:fghvp), Val(:fjvp)),
103108
)
104109
const _INPLACEOBJECTIVE_SPEC_V7 = (
105-
fields=(:df, :fdf, :fgh, :hv, :fghv), x_last=(:df, :fdf, :fgh), xv_tail=(:hv, :fghv)
110+
field_syms = (:df, :fdf, :fgh, :hv, :fghv),
111+
fields = (Val(:df), Val(:fdf), Val(:fgh), Val(:hv), Val(:fghv)),
112+
x_last = (Val(:df), Val(:fdf), Val(:fgh)),
113+
xv_tail = (Val(:hv), Val(:fghv)),
106114
)
107115

108116
@inline function _wrap_inplaceobjective_field(
109-
::Val{field}, f::NLSolversBase.InplaceObjective, tree::N, refs, spec
117+
v_field::Val{field}, f::NLSolversBase.InplaceObjective, tree::N, refs, spec
110118
) where {field,N<:Union{AbstractExpressionNode,AbstractExpression}}
111-
if field in spec.x_last
119+
if v_field in spec.x_last
112120
return _wrap_objective_x_last(getfield(f, field), tree, refs)
113-
elseif field in spec.xv_tail
121+
elseif v_field in spec.xv_tail
114122
return _wrap_objective_xv_tail(getfield(f, field), tree, refs)
115123
else
116124
throw(
@@ -125,8 +133,8 @@ end
125133
@inline function _wrap_inplaceobjective(
126134
f::NLSolversBase.InplaceObjective, tree::N, refs, spec
127135
) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
128-
wrapped = map(spec.fields) do field
129-
_wrap_inplaceobjective_field(Val(field), f, tree, refs, spec)
136+
wrapped = map(spec.fields) do v_field
137+
_wrap_inplaceobjective_field(v_field, f, tree, refs, spec)
130138
end
131139
return NLSolversBase.InplaceObjective(wrapped...)
132140
end
@@ -141,14 +149,15 @@ function wrap_func(
141149
# We use `@static` branching so that only the relevant layout for the *installed*
142150
# NLSolversBase version is compiled/instrumented.
143151
@static if fieldnames(NLSolversBase.InplaceObjective) ==
144-
_INPLACEOBJECTIVE_SPEC_V8.fields
152+
_INPLACEOBJECTIVE_SPEC_V8.field_syms
145153
# NLSolversBase v8 / Optim v2
146154
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V8)
147-
elseif fieldnames(NLSolversBase.InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V7.fields
155+
elseif fieldnames(NLSolversBase.InplaceObjective) == _INPLACEOBJECTIVE_SPEC_V7.field_syms
148156
# NLSolversBase v7 / Optim v1
149157
return _wrap_inplaceobjective(f, tree, refs, _INPLACEOBJECTIVE_SPEC_V7)
150158
# (Optim < 1 is no longer supported.)
151159
else
160+
# LCOV_EXCL_START
152161
fields = fieldnames(NLSolversBase.InplaceObjective)
153162
throw(
154163
ArgumentError(
@@ -157,6 +166,7 @@ function wrap_func(
157166
"Please open an issue at github.com/SymbolicML/DynamicExpressions.jl with your versions.",
158167
),
159168
)
169+
# LCOV_EXCL_END
160170
end
161171
end
162172

test/runtests.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
11
using SafeTestsets
22
using TestItemRunner
33

4-
# Check if SR_ENZYME_TEST is set in env
5-
test_name = split(get(ENV, "SR_TEST", "main"), ",")
4+
# Control which test groups run.
5+
#
6+
# Accepts a comma-separated list in SR_TEST (default: "main").
7+
#
8+
# - "main": full test suite (testitems)
9+
# - "optim": Optim-specific testitems only
10+
# - "narity": n-ary testitems only
11+
# - "jet": JET analysis
12+
# - "enzyme": Enzyme tests
613

7-
unknown_tests = filter(Base.Fix2(, ["enzyme", "jet", "main", "narity"]), test_name)
14+
test_names = split(get(ENV, "SR_TEST", "main"), ",")
15+
16+
allowed = ["enzyme", "jet", "main", "narity", "optim"]
17+
unknown_tests = filter(Base.Fix2(, allowed), test_names)
818

919
if !isempty(unknown_tests)
1020
error("Unknown test names: $unknown_tests")
1121
end
1222

13-
if "enzyme" in test_name
23+
if "enzyme" in test_names
1424
@safetestset "Test enzyme derivatives" begin
1525
include("test_enzyme.jl")
1626
end
1727
end
18-
if "jet" in test_name
28+
29+
if "jet" in test_names
1930
@safetestset "JET" begin
2031
using Preferences
2132
set_preferences!(
@@ -61,7 +72,25 @@ if "jet" in test_name
6172
end
6273
end
6374
end
64-
if "main" in test_name
65-
include("unittest.jl")
75+
76+
# Collect testitem definitions to load, then run them once.
77+
testitem_files = String[]
78+
79+
if "main" in test_names
80+
push!(testitem_files, "unittest.jl")
81+
else
82+
if "optim" in test_names
83+
push!(testitem_files, "test_optim.jl")
84+
end
85+
if "narity" in test_names
86+
push!(testitem_files, "test_n_arity_nodes.jl")
87+
end
88+
end
89+
90+
for f in testitem_files
91+
include(f)
92+
end
93+
94+
if !isempty(testitem_files)
6695
@run_package_tests
6796
end

test/test_optim.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,14 @@ end
222222
dummy = (args...) -> nothing
223223
obj = Optim.NLSolversBase.InplaceObjective((dummy for _ in fields)...)
224224

225-
spec = if fields == ext._INPLACEOBJECTIVE_SPEC_V8.fields
225+
spec = if fields == ext._INPLACEOBJECTIVE_SPEC_V8.field_syms
226226
ext._INPLACEOBJECTIVE_SPEC_V8
227-
elseif fields == ext._INPLACEOBJECTIVE_SPEC_V7.fields
227+
elseif fields == ext._INPLACEOBJECTIVE_SPEC_V7.field_syms
228228
ext._INPLACEOBJECTIVE_SPEC_V7
229229
else
230-
ext._INPLACEOBJECTIVE_SPEC_OLD
230+
# Should be unreachable for supported Optim majors, but pick one so the
231+
# test still compiles if a future layout appears.
232+
ext._INPLACEOBJECTIVE_SPEC_V8
231233
end
232234

233235
@test_throws ArgumentError ext._wrap_inplaceobjective_field(

0 commit comments

Comments
 (0)