
Commit 4ef0199

Merge pull request #28 from LAMPSPUC/dev
Dev
2 parents 137da5e + 7d08453 commit 4ef0199

File tree

8 files changed (+37, -34 lines)


Project.toml

Lines changed: 4 additions & 2 deletions
@@ -11,18 +11,20 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 BilevelJuMP = "0.6.2"
 DiffOpt = "0.5.0"
-Flux = "0.14.25"
+Flux = "0.16.3"
 JobQueueMPI = "0.1.1"
 JuMP = "1.24"
-Optim = "1.11"
 MPI = "0.20.22"
+Optim = "1.11"
+Optimisers = "0.4.5"
 ParametricOptInterface = "0.9.0"
 Zygote = "0.6.75"
 julia = "1.10"
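
For context, the compat bump pins Flux to the 0.16 series, which moves away from the implicit `Flux.params` training style, and adds Optimisers.jl as a direct dependency. A minimal, hypothetical way to install matching versions (not part of this commit):

    using Pkg
    Pkg.add(name = "Flux", version = "0.16.3")        # matches the [compat] entry above
    Pkg.add(name = "Optimisers", version = "0.4.5")   # new direct dependency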

src/flux_utils.jl

Lines changed: 5 additions & 5 deletions
@@ -7,7 +7,7 @@ Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single
 vector.
 """
 function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
-    θ = Flux.params(model)
+    θ = Flux.trainables(model)
     return reduce(vcat, [vec(p) for p in θ])
 end
 
@@ -21,7 +21,7 @@ function fix_flux_params_single_model(
     θ::Vector{<:Real},
 )
     i = 1
-    for p in Flux.params(model)
+    for p in Flux.trainables(model)
         psize = prod(size(p))
         p .= reshape(θ[i:i+psize-1], size(p))
         i += psize
@@ -38,7 +38,7 @@ of parameters.
 function fix_flux_params_multi_model(models, θ::Vector{<:Real})
     i = 1
     for model in models
-        for p in Flux.params(model)
+        for p in Flux.trainables(model)
             psize = prod(size(p))
             p .= reshape(θ[i:i+psize-1], size(p))
             i += psize
@@ -54,8 +54,8 @@ Check if a Flux layer has parameters.
 """
 function has_params(layer)
     try
-        # Attempt to get parameters; if it works and isn't empty, return true
-        return !isempty(Flux.params(layer))
+        # Attempt to get trainable parameters; if it works and isn't empty, return true
+        return !isempty(Flux.trainable(layer))
     catch e
         # If there is an error (e.g. method not matching), assume no parameters
         return false
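
For reference, a minimal sketch of the `Flux.trainables` call that replaces `Flux.params` above, on a hypothetical toy layer (not from this repo): it returns a vector of the model's trainable arrays, which `extract_flux_params` then flattens.

    using Flux

    toy = Flux.Dense(3 => 2)                  # hypothetical 3-input, 2-output layer
    ps = Flux.trainables(toy)                 # vector of arrays: [weight (2×3), bias (2,)]
    θ = reduce(vcat, [vec(p) for p in ps])    # flat parameter vector, length 2*3 + 2 = 8
    @assert length(θ) == 8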

src/optimizers/bilevel.jl

Lines changed: 2 additions & 2 deletions
@@ -177,10 +177,10 @@ function solve_bilevel(
     ilayer = 1
     for layer in model.forecast.networks[1]
         if has_params(layer)
-            for p in Flux.params(layer.weight)
+            for p in Flux.trainables(layer.weight)
                 p .= value.(predictive_model_vars[ilayer][:W])
             end
-            for p in Flux.params(layer.bias)
+            for p in Flux.trainables(layer.bias)
                 p .= value.(predictive_model_vars[ilayer][:b])
             end
         end
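
The loop above writes the solved JuMP variable values back into the first network's layers in place. A hypothetical standalone sketch of the same effect on a single `Flux.Dense` layer (all names illustrative, stand-ins for the `value.(...)` calls above):

    using Flux

    layer = Flux.Dense(2 => 2)
    W_sol = [1.0 2.0; 3.0 4.0]    # stand-in for value.(predictive_model_vars[ilayer][:W])
    b_sol = [0.1, 0.2]            # stand-in for value.(predictive_model_vars[ilayer][:b])
    layer.weight .= W_sol         # in-place copy, same effect as the trainables loops above
    layer.bias .= b_sol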

src/optimizers/gradient.jl

Lines changed: 2 additions & 1 deletion
@@ -42,6 +42,7 @@ function train_with_gradient!(
     best_θ = extract_params(model.forecast)
     trace = Array{Float64}(undef, epochs)
     stochastic = batch_size > 0
+    opt_state = Flux.setup(rule, model.forecast)
 
     # precompute batches
     batches = repeat(1:T, outer = (1, epochs))'
@@ -87,7 +88,7 @@ function train_with_gradient!(
         end
 
         # take gradient step
-        apply_gradient!(model.forecast, dC, epochx, rule)
+        apply_gradient!(model.forecast, dC, epochx, opt_state)
     end
 
     # fix best model
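
For reference, a minimal hypothetical sketch of the `Flux.setup` pattern introduced above: the optimiser state is built once from the rule and the model before the epoch loop, then reused at every gradient step (here with a toy model and loss in place of the repo's `apply_gradient!`):

    using Flux, Optimisers, Zygote, Statistics

    toy_model = Flux.Dense(4 => 1)            # hypothetical stand-in for model.forecast
    rule = Flux.Descent(0.1)
    opt_state = Flux.setup(rule, toy_model)   # built once, before the loop

    X = rand(Float32, 4, 16)                  # toy batch: 4 features × 16 samples
    for epoch in 1:3
        grad = Zygote.gradient(m -> mean(m(X)), toy_model)[1]
        Optimisers.update!(opt_state, toy_model, grad)   # mutates model and state in place
    end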

src/optimizers/gradient_mpi.jl

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ function train_with_gradient_mpi!(
     T = size(X)[1]
     stochastic = batch_size > 0
     compute_full_cost = true
+    opt_state = Flux.setup(rule, model.forecast)
 
     # precompute batches
     batches = repeat(1:T, outer = (1, epochs))'
@@ -121,7 +122,7 @@ function train_with_gradient_mpi!(
         end
 
         # take gradient step (if not last epoch)
-        apply_gradient!(model.forecast, dCdy, epochx, rule)
+        apply_gradient!(model.forecast, dCdy, epochx, opt_state)
     end
 
     # release workers

src/predictive_model.jl

Lines changed: 7 additions & 7 deletions
@@ -1,6 +1,7 @@
 using Flux
 using Statistics
 import Zygote
+import Optimisers
 
 """
     PredictiveModel(networks, input_output_map, input_size, output_size)
@@ -198,7 +199,7 @@ function apply_params(model::PredictiveModel, θ)
 end
 
 """
-    apply_gradient!(model, dCdy, X, optimizer)
+    apply_gradient!(model, dCdy, X, rule)
 
 Apply a gradient vector to the model parameters.
 
@@ -209,17 +210,16 @@ Apply a gradient vector to the model parameters.
 - `model::PredictiveModel`: model to be updated.
 - `dCdy::Vector{<:Real}`: gradient vector.
 - `X::Matrix{<:Real}`: input data.
-- `optimizer`: Optimiser to be used.
+- `rule`: Optimisation rule.
 ...
 """
 function apply_gradient!(
     model::PredictiveModel,
     dCdy::Vector{<:Real},
     X::Matrix{<:Real},
-    optimizer,
+    opt_state,
 )
-    ps = Flux.params(model.networks)
-    loss(x, y) = mean(dCdy'model(x))
-    train_data = [(X', 0.0)]
-    return Flux.train!(loss, ps, train_data, optimizer)
+    loss3(m, X) = mean(dCdy'm(X'))
+    grad = Zygote.gradient(loss3, model, X)[1]
+    return Optimisers.update!(opt_state, model, grad)
 end
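
The rewritten `apply_gradient!` replaces the implicit `Flux.params`/`Flux.train!` call with an explicit Zygote gradient applied through a pre-built Optimisers state. A hypothetical minimal version of the same pattern, with a plain `Flux.Dense` standing in for `PredictiveModel`:

    using Flux, Optimisers, Zygote, Statistics

    toy = Flux.Dense(3 => 2)                        # stand-in for a PredictiveModel
    opt_state = Flux.setup(Flux.Descent(0.1), toy)  # built once by the caller (see gradient.jl)

    X = rand(Float32, 4, 3)                         # 4 samples × 3 inputs, rows are samples
    dCdy = rand(Float32, 2)                         # cost gradient w.r.t. the model output

    # same structure as the new apply_gradient!: differentiate dCdy' * m(X') with respect
    # to the model, then apply the explicit gradient through the optimiser state
    loss3(m, X) = mean(dCdy' * m(X'))
    grad = Zygote.gradient(loss3, toy, X)[1]
    Optimisers.update!(opt_state, toy, grad)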

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/test_predictive_model.jl

Lines changed: 14 additions & 16 deletions
@@ -24,11 +24,10 @@ out_size = 2
         forecaster,
         ones(out_size),
         ones((1, in_size)),
-        Flux.Descent(0.1),
+        Flux.setup(Flux.Descent(0.1), forecaster),
     )
-    @test Flux.params(forecaster.networks[1])[1] ==
-          0.9 * ones((out_size, in_size))
-    @test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(out_size)
+    @test Flux.trainables(forecaster)[1] == 0.9 * ones((out_size, in_size))
+    @test Flux.trainables(forecaster)[2] == 0.9 * ones(out_size)
 end
 
 @testset "Single-Chain" begin
@@ -53,11 +52,10 @@ end
         forecaster,
         ones(out_size),
         ones((1, in_size)),
-        Flux.Descent(0.1),
+        Flux.setup(Flux.Descent(0.1), forecaster),
     )
-    @test Flux.params(forecaster.networks[1])[1] ==
-          0.9 * ones((out_size, in_size))
-    @test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(out_size)
+    @test Flux.trainables(forecaster)[1] == 0.9 * ones((out_size, in_size))
+    @test Flux.trainables(forecaster)[2] == 0.9 * ones(out_size)
 end
 
 @testset "Multi-Variate-Dense" begin
@@ -84,11 +82,11 @@ end
         forecaster,
         ones(out_size),
         ones((1, in_size)),
-        Flux.Descent(0.1),
+        Flux.setup(Flux.Descent(0.1), forecaster),
     )
-    @test Flux.params(forecaster.networks[1])[1] ==
+    @test Flux.trainables(forecaster)[1] ==
           0.8 * ones((model_out_size, model_in_size))
-    @test Flux.params(forecaster.networks[1])[2] == 0.8 * ones(model_out_size)
+    @test Flux.trainables(forecaster)[2] == 0.8 * ones(model_out_size)
 end
 
 @testset "Multi-Model-Dense" begin
@@ -122,12 +120,12 @@ end
         forecaster,
         ones(out_size),
         ones((1, in_size)),
-        Flux.Descent(0.1),
+        Flux.setup(Flux.Descent(0.1), forecaster),
     )
-    @test Flux.params(forecaster.networks[1])[1] ==
+    @test Flux.trainables(forecaster)[1] ==
           0.9 * ones((model_out_size, model_in_size))
-    @test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(model_out_size)
-    @test Flux.params(forecaster.networks[2])[1] ==
+    @test Flux.trainables(forecaster)[2] == 0.9 * ones(model_out_size)
+    @test Flux.trainables(forecaster)[3] ==
           0.9 * ones((model_out_size, model_in_size))
-    @test Flux.params(forecaster.networks[2])[2] == 0.9 * ones(model_out_size)
+    @test Flux.trainables(forecaster)[4] == 0.9 * ones(model_out_size)
 end
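
The updated tests index `Flux.trainables(forecaster)` directly instead of `Flux.params(forecaster.networks[i])`: `trainables` walks the whole container and returns one flat vector of arrays, so with two Dense networks indices 1-2 are the first network's weight and bias and indices 3-4 the second's. A hypothetical sketch of that ordering on a plain `Flux.Chain` (not the repo's `PredictiveModel` type):

    using Flux

    two_nets = Flux.Chain(Flux.Dense(3 => 2), Flux.Dense(2 => 2))  # stand-in container
    ps = Flux.trainables(two_nets)
    @assert length(ps) == 4
    @assert size(ps[1]) == (2, 3) && size(ps[2]) == (2,)   # first Dense: weight, bias
    @assert size(ps[3]) == (2, 2) && size(ps[4]) == (2,)   # second Dense: weight, bias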
