Skip to content

Commit 8f91cf9

Browse files
committed
replace usage of deprecated Flux.params to Flux.trainables
1 parent 0480365 commit 8f91cf9

File tree

4 files changed

+18
-18
lines changed

4 files changed

+18
-18
lines changed

src/flux_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single
77
vector.
88
"""
99
function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
10-
θ = Flux.params(model)
10+
θ = Flux.trainables(model)
1111
return reduce(vcat, [vec(p) for p in θ])
1212
end
1313

@@ -21,7 +21,7 @@ function fix_flux_params_single_model(
2121
θ::Vector{<:Real},
2222
)
2323
i = 1
24-
for p in Flux.params(model)
24+
for p in Flux.trainables(model)
2525
psize = prod(size(p))
2626
p .= reshape(θ[i:i+psize-1], size(p))
2727
i += psize
@@ -38,7 +38,7 @@ of parameters.
3838
function fix_flux_params_multi_model(models, θ::Vector{<:Real})
3939
i = 1
4040
for model in models
41-
for p in Flux.params(model)
41+
for p in Flux.trainables(model)
4242
psize = prod(size(p))
4343
p .= reshape(θ[i:i+psize-1], size(p))
4444
i += psize
@@ -54,8 +54,8 @@ Check if a Flux layer has parameters.
5454
"""
5555
function has_params(layer)
5656
try
57-
# Attempt to get parameters; if it works and isn't empty, return true
58-
return !isempty(Flux.params(layer))
57+
# Attempt to get trainable parameters; if it works and isn't empty, return true
58+
return !isempty(Flux.trainable(layer))
5959
catch e
6060
# If there is an error (e.g. method not matching), assume no parameters
6161
return false

src/optimizers/bilevel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,10 @@ function solve_bilevel(
177177
ilayer = 1
178178
for layer in model.forecast.networks[1]
179179
if has_params(layer)
180-
for p in Flux.params(layer.weight)
180+
for p in Flux.trainables(layer.weight)
181181
p .= value.(predictive_model_vars[ilayer][:W])
182182
end
183-
for p in Flux.params(layer.bias)
183+
for p in Flux.trainables(layer.bias)
184184
p .= value.(predictive_model_vars[ilayer][:b])
185185
end
186186
end

src/predictive_model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,5 @@ function apply_gradient!(
221221
)
222222
loss3(m, X) = mean(dCdy'm(X'))
223223
grad = Zygote.gradient(loss3, model, X)[1]
224-
Optimisers.update!(opt_state, model, grad)
224+
return Optimisers.update!(opt_state, model, grad)
225225
end

test/test_predictive_model.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ out_size = 2
2626
ones((1, in_size)),
2727
Flux.setup(Flux.Descent(0.1), forecaster),
2828
)
29-
@test Flux.params(forecaster.networks[1])[1] ==
29+
@test Flux.trainables(forecaster)[1] ==
3030
0.9 * ones((out_size, in_size))
31-
@test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(out_size)
31+
@test Flux.trainables(forecaster)[2] == 0.9 * ones(out_size)
3232
end
3333

3434
@testset "Single-Chain" begin
@@ -55,9 +55,9 @@ end
5555
ones((1, in_size)),
5656
Flux.setup(Flux.Descent(0.1), forecaster),
5757
)
58-
@test Flux.params(forecaster.networks[1])[1] ==
58+
@test Flux.trainables(forecaster)[1] ==
5959
0.9 * ones((out_size, in_size))
60-
@test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(out_size)
60+
@test Flux.trainables(forecaster)[2] == 0.9 * ones(out_size)
6161
end
6262

6363
@testset "Multi-Variate-Dense" begin
@@ -86,9 +86,9 @@ end
8686
ones((1, in_size)),
8787
Flux.setup(Flux.Descent(0.1), forecaster),
8888
)
89-
@test Flux.params(forecaster.networks[1])[1] ==
89+
@test Flux.trainables(forecaster)[1] ==
9090
0.8 * ones((model_out_size, model_in_size))
91-
@test Flux.params(forecaster.networks[1])[2] == 0.8 * ones(model_out_size)
91+
@test Flux.trainables(forecaster)[2] == 0.8 * ones(model_out_size)
9292
end
9393

9494
@testset "Multi-Model-Dense" begin
@@ -124,10 +124,10 @@ end
124124
ones((1, in_size)),
125125
Flux.setup(Flux.Descent(0.1), forecaster),
126126
)
127-
@test Flux.params(forecaster.networks[1])[1] ==
127+
@test Flux.trainables(forecaster)[1] ==
128128
0.9 * ones((model_out_size, model_in_size))
129-
@test Flux.params(forecaster.networks[1])[2] == 0.9 * ones(model_out_size)
130-
@test Flux.params(forecaster.networks[2])[1] ==
129+
@test Flux.trainables(forecaster)[2] == 0.9 * ones(model_out_size)
130+
@test Flux.trainables(forecaster)[3] ==
131131
0.9 * ones((model_out_size, model_in_size))
132-
@test Flux.params(forecaster.networks[2])[2] == 0.9 * ones(model_out_size)
132+
@test Flux.trainables(forecaster)[4] == 0.9 * ones(model_out_size)
133133
end

0 commit comments

Comments
 (0)