Skip to content

Commit 0480365

Browse files
committed
fix Flux code to version 0.16.3 and update deps
1 parent 99f8cb8 commit 0480365

File tree

6 files changed

+20
-15
lines changed

6 files changed

+20
-15
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
1111
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1212
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1313
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
14+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1415
ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
1920
BilevelJuMP = "0.6.2"
2021
DiffOpt = "0.5.0"
21-
Flux = "0.14.25"
22+
Flux = "0.16.3"
2223
JobQueueMPI = "0.1.1"
2324
JuMP = "1.24"
24-
Optim = "1.11"
2525
MPI = "0.20.22"
26+
Optim = "1.11"
27+
Optimisers = "0.4.5"
2628
ParametricOptInterface = "0.9.0"
2729
Zygote = "0.6.75"
2830
julia = "1.10"

src/optimizers/gradient.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ function train_with_gradient!(
4242
best_θ = extract_params(model.forecast)
4343
trace = Array{Float64}(undef, epochs)
4444
stochastic = batch_size > 0
45+
opt_state = Flux.setup(rule, model.forecast)
4546

4647
# precompute batches
4748
batches = repeat(1:T, outer = (1, epochs))'
@@ -87,7 +88,7 @@ function train_with_gradient!(
8788
end
8889

8990
# take gradient step
90-
apply_gradient!(model.forecast, dC, epochx, rule)
91+
apply_gradient!(model.forecast, dC, epochx, opt_state)
9192
end
9293

9394
# fix best model

src/optimizers/gradient_mpi.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ function train_with_gradient_mpi!(
3232
T = size(X)[1]
3333
stochastic = batch_size > 0
3434
compute_full_cost = true
35+
opt_state = Flux.setup(rule, model.forecast)
3536

3637
# precompute batches
3738
batches = repeat(1:T, outer = (1, epochs))'
@@ -121,7 +122,7 @@ function train_with_gradient_mpi!(
121122
end
122123

123124
# take gradient step (if not last epoch)
124-
apply_gradient!(model.forecast, dCdy, epochx, rule)
125+
apply_gradient!(model.forecast, dCdy, epochx, opt_state)
125126
end
126127

127128
# release workers

src/predictive_model.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Flux
22
using Statistics
33
import Zygote
4+
import Optimisers
45

56
"""
67
PredictiveModel(networks, input_output_map, input_size, output_size)
@@ -198,7 +199,7 @@ function apply_params(model::PredictiveModel, θ)
198199
end
199200

200201
"""
201-
apply_gradient!(model, dCdy, X, optimizer)
202+
apply_gradient!(model, dCdy, X, rule)
202203
203204
Apply a gradient vector to the model parameters.
204205
@@ -209,17 +210,16 @@ Apply a gradient vector to the model parameters.
209210
- `model::PredictiveModel`: model to be updated.
210211
- `dCdy::Vector{<:Real}`: gradient vector.
211212
- `X::Matrix{<:Real}`: input data.
212-
- `optimizer`: Optimiser to be used.
213+
- `rule`: Optimisation rule.
213214
...
214215
"""
215216
function apply_gradient!(
216217
model::PredictiveModel,
217218
dCdy::Vector{<:Real},
218219
X::Matrix{<:Real},
219-
optimizer,
220+
opt_state,
220221
)
221-
ps = Flux.params(model.networks)
222-
loss(x, y) = mean(dCdy'model(x))
223-
train_data = [(X', 0.0)]
224-
return Flux.train!(loss, ps, train_data, optimizer)
222+
loss3(m, X) = mean(dCdy'm(X'))
223+
grad = Zygote.gradient(loss3, model, X)[1]
224+
Optimisers.update!(opt_state, model, grad)
225225
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ JobQueueMPI = "32d208e1-246e-420c-b6ff-18b71b410923"
88
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
99
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1010
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
11+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1112
ParametricOptInterface = "0ce4ce61-57bf-432b-a095-efac525d185e"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

test/test_predictive_model.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ out_size = 2
2424
forecaster,
2525
ones(out_size),
2626
ones((1, in_size)),
27-
Flux.Descent(0.1),
27+
Flux.setup(Flux.Descent(0.1), forecaster),
2828
)
2929
@test Flux.params(forecaster.networks[1])[1] ==
3030
0.9 * ones((out_size, in_size))
@@ -53,7 +53,7 @@ end
5353
forecaster,
5454
ones(out_size),
5555
ones((1, in_size)),
56-
Flux.Descent(0.1),
56+
Flux.setup(Flux.Descent(0.1), forecaster),
5757
)
5858
@test Flux.params(forecaster.networks[1])[1] ==
5959
0.9 * ones((out_size, in_size))
@@ -84,7 +84,7 @@ end
8484
forecaster,
8585
ones(out_size),
8686
ones((1, in_size)),
87-
Flux.Descent(0.1),
87+
Flux.setup(Flux.Descent(0.1), forecaster),
8888
)
8989
@test Flux.params(forecaster.networks[1])[1] ==
9090
0.8 * ones((model_out_size, model_in_size))
@@ -122,7 +122,7 @@ end
122122
forecaster,
123123
ones(out_size),
124124
ones((1, in_size)),
125-
Flux.Descent(0.1),
125+
Flux.setup(Flux.Descent(0.1), forecaster),
126126
)
127127
@test Flux.params(forecaster.networks[1])[1] ==
128128
0.9 * ones((model_out_size, model_in_size))

0 commit comments

Comments
 (0)