Commit 9050ef0

Add WeightNorm reparametrization (#2550)

1 parent 2bbd8b3 commit 9050ef0

8 files changed (+219, -3 lines)

NEWS.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -2,6 +2,9 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
+## v0.15.3
+* Add `WeightNorm` normalization layer.
+
 ## v0.15.0 (December 2024)
 This release includes two **breaking changes**:
 - The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
```

docs/src/reference/models/layers.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -126,6 +126,8 @@ AlphaDropout
 LayerNorm
 InstanceNorm
 GroupNorm
+WeightNorm
+Flux.remove_weight_norms
 Flux.normalise
 ```
 
````
ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 1 addition & 2 deletions

```diff
@@ -3,8 +3,7 @@ module FluxAMDGPUExt
 import ChainRulesCore
 import ChainRulesCore: NoTangent
 import Flux
-import Flux: adapt_storage, fmap
-import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
+import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
 using MLDataDevices
 using AMDGPU
```

src/Flux.jl

Lines changed: 2 additions & 1 deletion

```diff
@@ -42,7 +42,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
   SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
   AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
   Dropout, AlphaDropout,
-  LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+  LayerNorm, BatchNorm, InstanceNorm, GroupNorm, WeightNorm,
   MultiHeadAttention,
   Upsample, PixelShuffle,
   fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
@@ -94,6 +94,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
   siamese_contrastive_loss,
   squared_hinge_loss,
   tversky_loss,
+  remove_weight_norms,
 ))
 
 include("gradient.jl")
```

src/layers/normalise.jl

Lines changed: 124 additions & 0 deletions

````diff
@@ -568,3 +568,127 @@ scale parameters, `false` otherwise.
 See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
 """
 hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine
+
+struct WeightNorm{L, G, D}
+    layer::L
+    g::G
+
+    which::Symbol
+    dims::D
+end
+@layer WeightNorm
+
+"""
+    WeightNorm(layer::L, which::Symbol = :weight; dims = -1)
+
+Apply weight normalization to a parameter given by `which` in a `layer`.
+
+``w = g \\frac{\\mathbf{v}}{\\lVert \\mathbf{v} \\rVert}``
+
+Decouples the magnitude of a weight tensor from its direction.
+By default, normalization is applied along the output channel `dim=-1`
+(equivalent to `dims=ndims(w)`).
+
+### Example
+
+```jldoctest
+julia> c = Conv((3,), 1 => 2);
+
+julia> wc = WeightNorm(c, :weight)
+WeightNorm(
+  Conv((3,), 1 => 2),           # 8 parameters
+  3×1×1 Array{Float32,...},     # 3 parameters
+  :weight,
+  3,
+)                   # Total: 3 arrays, 11 parameters, 276 bytes.
+
+julia> x = ones(Float32, 12, 1, 1);
+
+julia> c(x) ≈ wc(x) # forward pass is the same as with the original layer
+true
+```
+
+# Reference
+
+Salimans & Kingma, _Weight Normalization_ (2016) <https://arxiv.org/abs/1602.07868>
+"""
+function WeightNorm(layer::L, which::Symbol = :weight; dims = -1) where L
+    hasfield(L, which) || throw(ArgumentError("`$L` does not have field `:$which`."))
+
+    x = getfield(layer, which)
+    iszero(x) && throw(ArgumentError(
+        "`$which` field for `$(typeof(layer))` is all zero, which will result in NaN."))
+
+    d = if dims isa Colon
+        1:ndims(x)
+    elseif dims == -1
+        dims = ndims(x)
+    else
+        dims
+    end
+
+    g = sqrt.(sum(abs2, x; dims) .+ eps(eltype(x)))
+    WeightNorm(layer, g, which, dims)
+end
+
+(w::WeightNorm)(x) = reparametrize(w)(x)
+
+"""
+    reparametrize(wn::WeightNorm)
+
+Apply `WeightNorm` reparametrization and return underlying `layer`.
+"""
+function reparametrize(wn::WeightNorm)
+    ϵ = eps(eltype(wn.g))
+    v = getfield(wn.layer, wn.which)
+    n2 = sum(abs2, v; wn.dims)
+    w = @. wn.g * v / sqrt(n2 + ϵ)
+
+    fields, ctor = Functors.functor(wn.layer)
+    return ctor(merge(
+        fields, NamedTuple{(wn.which,)}((w,)),
+    ))
+end
+
+function Base.show(io::IO, w::WeightNorm)
+    print(io, "WeightNorm(")
+    Base.show(io, w.layer)
+    print(io, ", :", w.which, "; dims=", w.dims)
+    print(io, ")")
+end
+
+"""
+    remove_weight_norms(x)
+
+Remove any [WeightNorm](@ref) parametrization in the model.
+
+### Example
+
+```jldoctest
+julia> model = Chain(
+           WeightNorm(Conv((3,), 1 => 2), :weight),
+           WeightNorm(Conv((3,), 2 => 2), :weight),
+       )
+Chain(
+  WeightNorm(
+    Conv((3,), 1 => 2),         # 8 parameters
+    3×1×1 Array{Float32,...},   # 3 parameters
+    :weight,
+    3,
+  ),
+  WeightNorm(
+    Conv((3,), 2 => 2),         # 14 parameters
+    3×2×1 Array{Float32,...},   # 6 parameters
+    :weight,
+    3,
+  ),
+)                   # Total: 6 arrays, 31 parameters, 588 bytes.
+
+julia> Flux.remove_weight_norms(model)
+Chain(
+  Conv((3,), 1 => 2),           # 8 parameters
+  Conv((3,), 2 => 2),           # 14 parameters
+)                   # Total: 4 arrays, 22 parameters, 392 bytes.
+```
+"""
+remove_weight_norms(x) = fmap(reparametrize, x; exclude=l -> l isa WeightNorm)
````
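
For orientation, here is a minimal end-to-end sketch of how the new wrapper is intended to be used, based on the API added above; it is not part of the commit, and the model sizes, toy data, and optimiser settings are made up. It trains a `WeightNorm`-wrapped `Dense` layer with the standard Flux training API, then calls `Flux.remove_weight_norms` to fold `g` and `v` back into a plain weight for inference.

```julia
using Flux

# Hypothetical toy model: WeightNorm wraps the first Dense layer's :weight by default.
model = Chain(
    WeightNorm(Dense(4 => 8, relu)),
    Dense(8 => 1),
)

x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)
opt_state = Flux.setup(Adam(1f-3), model)

for _ in 1:10
    # Gradients flow through the reparametrized weight, i.e. w.r.t. g and the direction v.
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end

# Fold the reparametrization away; the forward pass is unchanged.
inference_model = Flux.remove_weight_norms(model)
@assert inference_model(x) ≈ model(x)
```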

src/layers/recurrent.jl

Lines changed: 1 addition & 0 deletions

````diff
@@ -103,6 +103,7 @@ x = rand(Float32, 10)
 
 # Run forward
 res = rnn(x, h0)
+```
 """
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
````
test/runtests.jl

Lines changed: 18 additions & 0 deletions

```diff
@@ -25,8 +25,20 @@ include("test_utils.jl") # for test_gradients
 
 Random.seed!(0)
 
+include("testsuite/normalization.jl")
+
+function flux_testsuite(dev)
+    @testset "Flux Test Suite" begin
+        @testset "Normalization" begin
+            normalization_testsuite(dev)
+        end
+    end
+end
+
 @testset verbose=true "Flux.jl" begin
     if get(ENV, "FLUX_TEST_CPU", "true") == "true"
+        flux_testsuite(cpu)
+
         @testset "Utils" begin
             include("utils.jl")
         end
@@ -84,6 +96,8 @@ Random.seed!(0)
     if CUDA.functional()
         @testset "CUDA" begin
             include("ext_cuda/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @warn "CUDA.jl package is not functional. Skipping CUDA tests."
@@ -99,6 +113,8 @@ Random.seed!(0)
     if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
         @testset "AMDGPU" begin
             include("ext_amdgpu/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
@@ -114,6 +130,8 @@ Random.seed!(0)
     if Metal.functional()
         @testset "Metal" begin
            include("ext_metal/runtests.jl")
+
+            flux_testsuite(gpu)
         end
     else
         @info "Metal.jl package is not functional. Skipping Metal tests."
```

test/testsuite/normalization.jl

Lines changed: 68 additions & 0 deletions

```diff
@@ -0,0 +1,68 @@
+function normalization_testsuite(dev)
+    @testset "WeightNorm" begin
+        x = rand(Float32, 1, 3) |> dev
+        mn = WeightNorm(Dense(1 => 2)) |> dev
+        m = Flux.remove_weight_norms(mn)
+        @test m(x) ≈ mn(x)
+
+        @test_throws ArgumentError WeightNorm(m, :weights)
+        @test_throws "does not have field" WeightNorm(m, :weights)
+
+        @test_throws ArgumentError WeightNorm(m, :bias)
+        @test_throws "is all zero" WeightNorm(m, :bias)
+
+        og = (Zygote.gradient(m) do m
+            sum(m(x))
+        end)[1]
+        g = (Zygote.gradient(mn) do mn
+            sum(mn(x))
+        end)[1]
+
+        @test g.layer.weight ≢ nothing # Original weight acts as a direction `v`.
+        @test g.layer.bias ≢ nothing
+        @test g.g ≢ nothing
+
+        # Compare gradients with original layer.
+
+        v = mn.layer.weight
+        ϵ = eps(eltype(v))
+        n2 = sum(abs2, v; dims=2)
+        v = v ./ sqrt.(n2 .+ ϵ)
+
+        @test (og.weight .* v) ≈ g.g
+        @test (og.weight .* mn.g .- mn.g .* g.g .* v) ≈ g.layer.weight atol=1f-6
+
+        # Test WeightNorm removal.
+
+        om = Flux.remove_weight_norms(mn)
+        @test om isa Dense
+        @test om.weight ≈ m.weight
+        @test om.bias ≈ m.bias
+
+        # Test with Chain.
+
+        c = Chain(
+            WeightNorm(Conv((3,), 1 => 2)),
+            Conv((3,), 2 => 2),
+            WeightNorm(Conv((3,), 2 => 3)),
+            x -> reshape(x, 18, :),
+            WeightNorm(Dense(18, 4)),
+            Dense(4, 1),
+        )
+        @test c[1] isa WeightNorm
+        @test c[2] isa Conv
+        @test c[3] isa WeightNorm
+        @test c[5] isa WeightNorm
+        @test c[6] isa Dense
+
+        oc = Flux.remove_weight_norms(c)
+        @test oc[1] isa Conv
+        @test oc[2] isa Conv
+        @test oc[3] isa Conv
+        @test oc[5] isa Dense
+        @test oc[6] isa Dense
+
+        x = rand(Float32, 12, 1, 1)
+        @test c(x) ≈ oc(x)
+    end
+end
```
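
For context on the gradient comparison in the test above: with the reparametrization used by `reparametrize`, the gradients with respect to the new parameters follow from the chain rule. A sketch of the identities as given in Salimans & Kingma (2016, Eq. 3), with the norm taken along the wrapper's `dims`; `L` denotes the loss, `v` the stored direction, and `g` the stored magnitude:

```latex
% Weight-norm reparametrization and its gradients (Salimans & Kingma, 2016, Eq. 3).
w = g \, \frac{v}{\lVert v \rVert},
\qquad
\nabla_g L = \frac{\nabla_w L \cdot v}{\lVert v \rVert},
\qquad
\nabla_v L = \frac{g}{\lVert v \rVert}\,\nabla_w L
           - \frac{g\,\nabla_g L}{\lVert v \rVert^{2}}\, v
```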
