Skip to content

Commit edb45ad

Browse files
Clean up test files and package changes
1 parent 051a39e commit edb45ad

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1313
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1414
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
15+
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
16+
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
17+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
1518
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1619
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1720
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -44,6 +47,7 @@ MLUtils = "0.4"
4447
ModelingToolkit = "10"
4548
Mooncake = "0.4.138"
4649
Optim = ">= 1.4.1"
50+
Optimisers = ">= 0.2.5"
4751
OptimizationBase = "2"
4852
OptimizationMOI = "0.5"
4953
OptimizationOptimJL = "0.4"
@@ -52,7 +56,7 @@ OrdinaryDiffEqTsit5 = "1"
5256
Pkg = "1"
5357
Printf = "1.10"
5458
ProgressLogging = "0.1"
55-
Random = "1.10"
59+
Random = "1.10"
5660
Reexport = "1.2"
5761
ReverseDiff = "1"
5862
SafeTestsets = "0.1"
@@ -64,7 +68,6 @@ Symbolics = "6"
6468
TerminalLoggers = "0.1"
6569
Test = "1.10"
6670
Tracker = "0.2"
67-
Optimisers = ">= 0.2.5"
6871
Zygote = "0.6, 0.7"
6972
julia = "1.10"
7073

@@ -83,6 +86,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
8386
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
8487
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
8588
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
89+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
8690
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
8791
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
8892
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
@@ -100,10 +104,6 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
100104
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
101105
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
102106
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
103-
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
104107

105108
[targets]
106-
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff",
107-
"Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers",
108-
"OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays", "SparseDiffTools",
109-
"Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]
109+
test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff", "Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays", "SparseDiffTools", "Symbolics", "Test", "Tracker", "Zygote", "Mooncake"]

test_sophia_fix.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using Optimization, ComponentArrays, Enzyme
2+
3+
# Test function - simple quadratic
4+
function rosenbrock(x, p)
5+
return (1 - x.a)^2 + 100 * (x.b - x.a^2)^2
6+
end
7+
8+
# Initial parameter as ComponentVector
9+
x0 = ComponentVector(a = 0.0, b = 0.0)
10+
11+
# Create optimization function with Enzyme autodiff
12+
optf = OptimizationFunction(rosenbrock, AutoEnzyme())
13+
14+
# Create optimization problem
15+
prob = OptimizationProblem(optf, x0)
16+
17+
# Test that Sophia optimizer works without shadow generation errors
18+
try
19+
sol = solve(prob, Optimization.Sophia=0.01, k=2), maxiters=5)
20+
println("✓ Sophia optimizer with ComponentArrays succeeded!")
21+
println("Solution: ", sol.u)
22+
println("Final objective: ", sol.objective)
23+
catch e
24+
println("✗ Error: ", e)
25+
rethrow(e)
26+
end

0 commit comments

Comments
 (0)