Skip to content

Commit cbdf36b

Browse files
authored
Add Flux 0.13 compatibility (#202)
* Add Flux 0.13 compatibility
* `params` -> `Flux.params`
* Forward-compatible (Flux 0.13 and 0.12) fix for tabular model
* Subtype `AbstractOptimiser`
1 parent f51ecb6 commit cbdf36b

File tree

6 files changed

+8
-7
lines changed

6 files changed

+8
-7
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2828
MosaicViews = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
2929
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
3030
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
31+
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
3132
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3233
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
3334
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -51,7 +52,7 @@ DataLoaders = "0.1"
5152
FileIO = "1.7"
5253
FilePathsBase = "0.9"
5354
FixedPointNumbers = "0.8"
54-
Flux = "0.12"
55+
Flux = "0.12, 0.13"
5556
FluxTraining = "0.2"
5657
Glob = "1"
5758
ImageInTerminal = "0.4"

src/Tabular/models.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function TabularModel(
4747

4848
tabularbackbone = Parallel(vcat, catbackbone, contbackbone)
4949

50-
classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers;
50+
classifierin = mapreduce(layer -> size(layer.weight)[1], +, Tuple(catbackbone[2].layers);
5151
init = contbackbone.chs)
5252
dropout_rates = Iterators.cycle(dropout_rates)
5353
classifiers = []
@@ -139,7 +139,7 @@ function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)
139139
emb_drop = iszero(dropout_rate) ? identity : Dropout(dropout_rate)
140140
Chain(
141141
x -> tuple(eachrow(x)...),
142-
Parallel(vcat, embedslist),
142+
Parallel(vcat, embedslist...),
143143
emb_drop
144144
)
145145
end

src/Vision/models/blocks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function visionhead(
3838
acts = vcat([relu for _ ∈ 1:n-2], [identity])
3939
pool = concat_pool ? AdaptiveConcatPool((1, 1)) : AdaptiveMeanPool((1, 1))
4040

41-
layers = [pool, flatten]
41+
layers = [pool, Flux.flatten]
4242

4343
for (h_in, h_out, act) in zip(hs, hs[2:end], acts)
4444
push!(layers, linbndrop(h_in, h_out, act=act, p=p))

src/training/discriminativelrs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dlro = DiscriminativeLRs(paramgroups, Dict(1 => 0., 2 => 1.))
2323
o = Optimiser(dlro, Descent(0.1))
2424
```
2525
"""
26-
struct DiscriminativeLRs
26+
struct DiscriminativeLRs <: Flux.Optimise.AbstractOptimiser
2727
pg::ParamGroups
2828
factorfn
2929
end

src/training/lrfind.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function lrfind(
4646
withfields(
4747
learner,
4848
model = modelcheckpoint,
49-
params = params(modelcheckpoint),
49+
params = Flux.params(modelcheckpoint),
5050
optimizer = deepcopy(learner.optimizer)
5151
) do
5252

src/training/paramgroups.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ getgroup(pg::ParamGroups, x::AbstractArray) = get(pg.map, x, nothing)
3232

3333
function assigngroups!(pg::ParamGroups, grouper, m)
3434
for (group, m_) in group(grouper, m)
35-
for p in params(m_)
35+
for p in Flux.params(m_)
3636
pg.map[p] = group
3737
end
3838
end

0 commit comments

Comments
 (0)