
Commit af07d62

Merge pull request #37 from LAMPSPUC/dev
Version 0.1.4
2 parents: c9df846 + e343055

File tree: 6 files changed, +71 −4 lines

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "ApplicationDrivenLearning"
 uuid = "0856f1c8-ef17-4e14-9230-2773e47a789e"
 authors = ["Giovanni Amorim", "Joaquim Garcia"]
-version = "0.1.3"
+version = "0.1.4"

 [deps]
 BilevelJuMP = "485130c0-026e-11ea-0f1a-6992cd14145c"

src/optimizers/gradient.jl

Lines changed: 12 additions & 0 deletions
@@ -34,6 +34,7 @@ function train_with_gradient!(
     verbose = get(params, :verbose, true)
     compute_cost_every = get(params, :compute_cost_every, 1)
     time_limit = get(params, :time_limit, Inf)
+    g_tol = get(params, :g_tol, 0)

     # init parameters
     start_time = time()
@@ -87,6 +88,17 @@ function train_with_gradient!(

         # check time limit reach
         if time() - start_time > time_limit
+            if verbose
+                println("Time limit reached.")
+            end
+            break
+        end
+
+        # check gradient tolerance
+        if maximum(abs.(dC)) < g_tol
+            if verbose
+                println("Gradient tolerance reached.")
+            end
             break
         end
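For context: the new g_tol rule stops the training loop once the largest absolute entry of the cost gradient dC falls below the tolerance, and the default g_tol = 0 means the check never fires, so previous behavior is preserved. A minimal standalone sketch of the same pattern (the quadratic cost and the toy_descent function are illustrative, not part of the package):

# Toy gradient descent on C(θ) = sum(θ .^ 2), stopping on the same
# sup-norm test used by train_with_gradient! above.
function toy_descent(θ0; step = 0.1, epochs = 1_000, g_tol = 1e-6, time_limit = Inf)
    θ = copy(θ0)
    start_time = time()
    for _ in 1:epochs
        dC = 2 .* θ                     # gradient of the toy cost
        θ .-= step .* dC
        if time() - start_time > time_limit
            println("Time limit reached.")
            break
        end
        if maximum(abs.(dC)) < g_tol    # same stop rule as in the diff
            println("Gradient tolerance reached.")
            break
        end
    end
    return θ
end

toy_descent([1.0, -2.0])  # converges, then prints "Gradient tolerance reached."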

src/optimizers/gradient_mpi.jl

Lines changed: 13 additions & 1 deletion
@@ -17,14 +17,15 @@ function train_with_gradient_mpi!(
     compute_cost_every = get(params, :compute_cost_every, 1)
     mpi_finalize = get(params, :mpi_finalize, true)
     time_limit = get(params, :time_limit, Inf)
+    g_tol = get(params, :g_tol, 0)

     JQM.mpi_init()

     # init parameters
     start_time = time()
     is_done = false
     best_C = Inf
-    best_θ = []
+    best_θ = extract_params(model.forecast)
     curr_C = 0.0
     trace = Array{Float64}(undef, epochs)
     dCdz = Vector{Float32}(undef, size(model.policy_vars, 1))
@@ -121,6 +122,17 @@ function train_with_gradient_mpi!(

         # check time limit reach
         if time() - start_time > time_limit
+            if verbose
+                println("Time limit reached.")
+            end
+            break
+        end
+
+        # check gradient tolerance
+        if maximum(abs.(dCdy)) < g_tol
+            if verbose
+                println("Gradient tolerance reached.")
+            end
             break
         end
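Besides sharing the g_tol rule (applied here to dCdy, the distributed gradient), the MPI path also seeds best_θ with the current forecast parameters instead of an empty array, so a run whose stop rule fires before any epoch completes still returns usable parameters. A standalone sketch of the failure mode the seeding avoids (toy names, not the package API):

# With epochs == 0 the loop body never runs, so whatever best_θ was
# seeded with is exactly what the caller gets back.
function toy_train(θ_current::Vector{Float32}; epochs::Int = 0)
    best_C = Inf32
    best_θ = copy(θ_current)   # fix: seed with current params, not Float32[]
    for _ in 1:epochs
        θ_trial = θ_current .+ 0.1f0 .* randn(Float32, length(θ_current))
        C = sum(abs2, θ_trial)              # toy cost
        if C < best_C
            best_C, best_θ = C, θ_trial
        end
    end
    return best_θ              # never empty, even when epochs == 0
end

toy_train(Float32[1.0, 2.0])   # returns Float32[1.0, 2.0] instead of Float32[]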

src/predictive_model.jl

Lines changed: 3 additions & 2 deletions
@@ -63,8 +63,9 @@ When only one network is passed as a Chain object, input and output
 indexes are directly extracted.
 """
 function PredictiveModel(network::Flux.Chain)
-    input_size = size(network[1].weight)[2]
-    output_size = size(network[end].weight)[1]
+    param_layers = [layer for layer in network if has_params(layer)]
+    input_size = size(param_layers[1].weight, 2)
+    output_size = size(param_layers[end].weight, 1)
     input_output_map = [Dict(collect(1:input_size) => collect(1:output_size))]
     return PredictiveModel(
         [deepcopy(network)],
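This change makes size inference robust to chains whose first or last entries carry no parameters (activation functions, reshaping closures, softmax, and the like), which previously raised an error on network[1].weight. A rough self-contained illustration, using hasproperty(layer, :weight) as a stand-in for the package's has_params helper (whose exact definition may differ):

using Flux

# Stand-in predicate: keep only layers that expose a weight matrix.
has_weights(layer) = hasproperty(layer, :weight)

# First and last entries are parameter-free, so network[1].weight would fail.
network = Chain(x -> Float32.(x), Dense(3 => 8, relu), Dense(8 => 2), softmax)

param_layers = [layer for layer in network if has_weights(layer)]
input_size = size(param_layers[1].weight, 2)     # 3, from Dense(3 => 8)
output_size = size(param_layers[end].weight, 1)  # 2, from Dense(8 => 2)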

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@ Random.seed!(123)
 include("utils.jl")
 include("test_predictive_model.jl")
 include("test_newsvendor.jl")
+include("test_gradient.jl")

test/test_gradient.jl

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# basic model for testing gradient mode
+X = Float32.(ones(1, 1))
+Y = Float32.(ones(1, 1))
+model = ApplicationDrivenLearning.Model()
+@variables(model, begin
+    x >= 0, ApplicationDrivenLearning.Policy
+    d, ApplicationDrivenLearning.Forecast
+end)
+@objective(ApplicationDrivenLearning.Plan(model), Min, x.plan)
+@objective(ApplicationDrivenLearning.Assess(model), Min, x.assess)
+set_optimizer(model, HiGHS.Optimizer)
+set_silent(model)
+ApplicationDrivenLearning.set_forecast_model(model, Chain(Dense(1 => 1)))
+
+@testset "GradientMode Stop Rules" begin
+    # epochs
+    initial_sol = ApplicationDrivenLearning.extract_params(model.forecast)
+    opt = ApplicationDrivenLearning.Options(
+        ApplicationDrivenLearning.GradientMode,
+        epochs = 0,
+    )
+    sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
+    @test initial_sol == sol.params
+
+    # time_limit
+    opt = ApplicationDrivenLearning.Options(
+        ApplicationDrivenLearning.GradientMode,
+        time_limit = 0,
+    )
+    sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
+    @test initial_sol == sol.params
+
+    # gradient norm
+    initial_sol = ApplicationDrivenLearning.extract_params(model.forecast)
+    opt = ApplicationDrivenLearning.Options(
+        ApplicationDrivenLearning.GradientMode,
+        g_tol = Inf,
+    )
+    sol = ApplicationDrivenLearning.train!(model, X, Y, opt)
+    @test initial_sol == sol.params
+end
