Skip to content

Commit 36d7654

Browse files
Added OptimizationAuglag tests
1 parent 1429b2e commit 36d7654

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

lib/OptimizationAuglag/Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,17 @@ name = "OptimizationAuglag"
22
uuid = "2ea93f80-9333-43a1-a68d-1f53b957a421"
33
authors = ["paramthakkar123 <[email protected]>"]
44
version = "0.1.0"
5+
6+
[deps]
7+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
9+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
10+
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
11+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12+
13+
[compat]
14+
ForwardDiff = "1.0.1"
15+
MLUtils = "0.4.8"
16+
Optimization = "4.4.0"
17+
OptimizationOptimisers = "0.3.8"
18+
Test = "1.11.0"

lib/OptimizationAuglag/src/OptimizationAuglag.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module OptimizationAuglag
22

3+
using Optimization.SciMLBase, Optimization
4+
35
@kwdef struct AugLag
46
inner::Any
57
τ = 0.5
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using MLUtils, OptimizationOptimisers
2+
using ForwardDiff
3+
using Test
4+
5+
@testset "OptimizationAuglag.jl" begin
6+
x0 = (-pi):0.001:pi
7+
y0 = sin.(x0)
8+
data = MLUtils.DataLoader((x0, y0), batchsize = 126)
9+
10+
function loss(coeffs, data)
11+
ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
12+
return sum(abs2, ypred .- data[2])
13+
end
14+
15+
function cons1(res, coeffs, p = nothing)
16+
res[1] = coeffs[1] * coeffs[5] - 1
17+
return nothing
18+
end
19+
20+
optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
21+
callback = (st, l) -> (@show l; return false)
22+
23+
initpars = rand(5)
24+
l0 = optf(initpars, (x0, y0))
25+
26+
prob = OptimizationProblem(optf, initpars, data, lcons = [-Inf], ucons = [1],
27+
lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
28+
opt = solve(
29+
prob, Optimization.AugLag(; inner = Adam()), maxiters = 10000, callback = callback)
30+
@test opt.objective < l0
31+
end

0 commit comments

Comments
 (0)