Skip to content

Commit f31801e

Browse files
authored
Merge branch 'main' into compathelper/new_version/2025-03-12-01-17-18-775-00042309767
2 parents 61a8365 + fadd140 commit f31801e

11 files changed

+78
-50
lines changed

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "ApplicationDrivenLearning"
22
uuid = "0856f1c8-ef17-4e14-9230-2773e47a789e"
3-
authors = ["Giovanni Amorin"]
4-
version = "0.1.0"
3+
authors = ["Giovanni Amorim", "Joaquim Garcia"]
4+
version = "0.1.1"
55

66
[deps]
77
BilevelJuMP = "485130c0-026e-11ea-0f1a-6992cd14145c"
@@ -11,18 +11,20 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
1111
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1212
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1313
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
14+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1415
ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
1920
BilevelJuMP = "0.6.2"
2021
DiffOpt = "0.5.0"
21-
Flux = "0.14.25"
22+
Flux = "0.16.3"
2223
JobQueueMPI = "0.1.1"
2324
JuMP = "1.24"
2425
MPI = "0.20.22"
2526
Optim = "1.11"
26-
ParametricOptInterface = "0.9.0"
27+
Optimisers = "0.4.5"
28+
ParametricOptInterface = "0.9.0, 0.10"
2729
Zygote = "0.6.75, 0.7"
2830
julia = "1.10"

src/ApplicationDrivenLearning.jl

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -135,15 +135,10 @@ and new constraint fixing to original forecast variables.
135135
function build_plan_model_forecast_params(model::Model)
136136
# adds parametrized forecast variables using MOI.Parameter
137137
forecast_size = size(model.forecast_vars)[1]
138-
model.plan_forecast_params = @variable(
139-
model.plan,
140-
_forecast[1:forecast_size] in MOI.Parameter.(zeros(forecast_size))
141-
)
142-
# fixes old and new prediction variables together
138+
model.plan_forecast_params = plan_forecast_vars(model)
143139
@constraint(
144140
model.plan,
145-
plan_forecast_fix,
146-
model.plan_forecast_params .== plan_forecast_vars(model)
141+
model.plan_forecast_params .∈ MOI.Parameter.(zeros(forecast_size))
147142
)
148143
end
149144

src/flux_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single
77
vector.
88
"""
99
function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
10-
θ = Flux.params(model)
10+
θ = Flux.trainables(model)
1111
return reduce(vcat, [vec(p) for p in θ])
1212
end
1313

@@ -21,7 +21,7 @@ function fix_flux_params_single_model(
2121
θ::Vector{<:Real},
2222
)
2323
i = 1
24-
for p in Flux.params(model)
24+
for p in Flux.trainables(model)
2525
psize = prod(size(p))
2626
p .= reshape(θ[i:i+psize-1], size(p))
2727
i += psize
@@ -38,7 +38,7 @@ of parameters.
3838
function fix_flux_params_multi_model(models, θ::Vector{<:Real})
3939
i = 1
4040
for model in models
41-
for p in Flux.params(model)
41+
for p in Flux.trainables(model)
4242
psize = prod(size(p))
4343
p .= reshape(θ[i:i+psize-1], size(p))
4444
i += psize
@@ -54,8 +54,8 @@ Check if a Flux layer has parameters.
5454
"""
5555
function has_params(layer)
5656
try
57-
# Attempt to get parameters; if it works and isn't empty, return true
58-
return !isempty(Flux.params(layer))
57+
# Attempt to get trainable parameters; if it works and isn't empty, return true
58+
return !isempty(Flux.trainable(layer))
5959
catch e
6060
# If there is an error (e.g. method not matching), assume no parameters
6161
return false

src/jump.jl

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,20 @@ function JuMP.add_variable(
6464
name * "_assess",
6565
),
6666
)
67+
68+
# forecast variables can't have bounds
69+
if JuMP.has_lower_bound(forecast.plan)
70+
@warn "Forecast variable lower bound will be removed."
71+
JuMP.delete_lower_bound(forecast.plan)
72+
JuMP.delete_lower_bound(forecast.assess)
73+
end
74+
75+
if JuMP.has_upper_bound(forecast.plan)
76+
@warn "Forecast variable upper bound will be removed."
77+
JuMP.delete_upper_bound(forecast.plan)
78+
JuMP.delete_upper_bound(forecast.assess)
79+
end
80+
6781
push!(model.forecast_vars, forecast)
6882
return forecast
6983
end

src/optimizers/bilevel.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -177,10 +177,10 @@ function solve_bilevel(
177177
ilayer = 1
178178
for layer in model.forecast.networks[1]
179179
if has_params(layer)
180-
for p in Flux.params(layer.weight)
180+
for p in Flux.trainables(layer.weight)
181181
p .= value.(predictive_model_vars[ilayer][:W])
182182
end
183-
for p in Flux.params(layer.bias)
183+
for p in Flux.trainables(layer.bias)
184184
p .= value.(predictive_model_vars[ilayer][:b])
185185
end
186186
end

src/optimizers/gradient.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ function train_with_gradient!(
4242
best_θ = extract_params(model.forecast)
4343
trace = Array{Float64}(undef, epochs)
4444
stochastic = batch_size > 0
45+
opt_state = Flux.setup(rule, model.forecast)
4546

4647
# precompute batches
4748
batches = repeat(1:T, outer = (1, epochs))'
@@ -87,7 +88,7 @@ function train_with_gradient!(
8788
end
8889

8990
# take gradient step
90-
apply_gradient!(model.forecast, dC, epochx, rule)
91+
apply_gradient!(model.forecast, dC, epochx, opt_state)
9192
end
9293

9394
# fix best model

src/optimizers/gradient_mpi.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ function train_with_gradient_mpi!(
3232
T = size(X)[1]
3333
stochastic = batch_size > 0
3434
compute_full_cost = true
35+
opt_state = Flux.setup(rule, model.forecast)
3536

3637
# precompute batches
3738
batches = repeat(1:T, outer = (1, epochs))'
@@ -121,7 +122,7 @@ function train_with_gradient_mpi!(
121122
end
122123

123124
# take gradient step (if not last epoch)
124-
apply_gradient!(model.forecast, dCdy, epochx, rule)
125+
apply_gradient!(model.forecast, dCdy, epochx, opt_state)
125126
end
126127

127128
# release workers

src/predictive_model.jl

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
using Flux
22
using Statistics
33
import Zygote
4+
import Optimisers
45

56
"""
67
PredictiveModel(networks, input_output_map, input_size, output_size)
@@ -198,7 +199,7 @@ function apply_params(model::PredictiveModel, θ)
198199
end
199200

200201
"""
201-
apply_gradient!(model, dCdy, X, optimizer)
202+
apply_gradient!(model, dCdy, X, rule)
202203
203204
Apply a gradient vector to the model parameters.
204205
@@ -209,17 +210,16 @@ Apply a gradient vector to the model parameters.
209210
- `model::PredictiveModel`: model to be updated.
210211
- `dCdy::Vector{<:Real}`: gradient vector.
211212
- `X::Matrix{<:Real}`: input data.
212-
- `optimizer`: Optimiser to be used.
213+
- `rule`: Optimisation rule.
213214
...
214215
"""
215216
function apply_gradient!(
216217
model::PredictiveModel,
217218
dCdy::Vector{<:Real},
218219
X::Matrix{<:Real},
219-
optimizer,
220+
opt_state,
220221
)
221-
ps = Flux.params(model.networks)
222-
loss(x, y) = mean(dCdy'model(x))
223-
train_data = [(X', 0.0)]
224-
return Flux.train!(loss, ps, train_data, optimizer)
222+
loss3(m, X) = mean(dCdy'm(X'))
223+
grad = Zygote.gradient(loss3, model, X)[1]
224+
return Optimisers.update!(opt_state, model, grad)
225225
end

src/simulation.jl

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,33 @@ function compute_single_step_cost(
33
y::Vector{<:Real},
44
yhat::Vector{<:Real},
55
)
6+
# set forecast params as prediction output
67
MOI.set.(model.plan, POI.ParameterValue(), model.plan_forecast_params, yhat)
8+
# optimize plan model
79
optimize!(model.plan)
8-
@assert termination_status(model.plan) == MOI.OPTIMAL "Optimization failed for PLAN model"
10+
# check for solution and fix assess policy vars
11+
try
12+
set_normalized_rhs.(
13+
model.assess[:assess_policy_fix],
14+
value.(plan_policy_vars(model)),
15+
)
16+
catch e
17+
println("Optimization failed for PLAN model.")
18+
throw(e)
19+
end
20+
# fix assess forecast vars on observer values
921
fix.(assess_forecast_vars(model), y; force = true)
10-
set_normalized_rhs.(
11-
model.assess[:assess_policy_fix],
12-
value.(plan_policy_vars(model)),
13-
)
22+
# optimize assess model
1423
optimize!(model.assess)
15-
@assert termination_status(model.assess) == MOI.OPTIMAL "Optimization failed for ASSESS model"
16-
return objective_value(model.assess)
24+
# check for optimization
25+
try
26+
return objective_value(model.assess)
27+
catch e
28+
println("Optimization failed for ASSESS model")
29+
throw(e)
30+
end
31+
# should never get here
32+
return 0
1733
end
1834

1935
"""

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
88
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
99
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1010
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
11+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1112
ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

0 commit comments

Comments (0)