Skip to content

Commit a25c1a7

Browse files
committed
add gradient mode test
1 parent b3dab96 commit a25c1a7

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ Random.seed!(123)
1414
include("utils.jl")
1515
include("test_predictive_model.jl")
1616
include("test_newsvendor.jl")
17+
include("test_gradient.jl")

test/test_gradient.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# basic model for testing gradient mode
2+
X = Float32.(ones(1, 1))
3+
Y = Float32.(ones(1, 1))
4+
model = ApplicationDrivenLearning.Model()
5+
@variables(model, begin
6+
x >= 0, ApplicationDrivenLearning.Policy
7+
d, ApplicationDrivenLearning.Forecast
8+
end)
9+
@objective(ApplicationDrivenLearning.Plan(model), Min, x.plan)
10+
@objective(ApplicationDrivenLearning.Assess(model), Min, x.assess)
11+
set_optimizer(model, HiGHS.Optimizer)
12+
set_silent(model)
13+
ApplicationDrivenLearning.set_forecast_model(model, Chain(Dense(1 => 1)))
14+
15+
@testset "GradientMode Stop Rules" begin
16+
# epochs
17+
opt = ApplicationDrivenLearning.Options(
18+
ApplicationDrivenLearning.GradientMode,
19+
epochs = 0,
20+
)
21+
sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
22+
@test initial_sol == sol.params
23+
24+
# time_limit
25+
initial_sol = ApplicationDrivenLearning.extract_params(model.forecast)
26+
opt = ApplicationDrivenLearning.Options(
27+
ApplicationDrivenLearning.GradientMode,
28+
time_limit = 0,
29+
)
30+
sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
31+
@test initial_sol == sol.params
32+
33+
# gradient norm
34+
initial_sol = ApplicationDrivenLearning.extract_params(model.forecast)
35+
opt = ApplicationDrivenLearning.Options(
36+
ApplicationDrivenLearning.GradientMode,
37+
g_tol = Inf,
38+
)
39+
sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
40+
@test initial_sol == sol.params
41+
end

0 commit comments

Comments
 (0)