
Commit e61a963: Resolve conflicts
2 parents 6a0032b + 57beb23

11 files changed: +322 additions, -203 deletions

docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
 makedocs(modules = [Flux, NNlib, Functors, MLUtils, BSON],
-         doctest = VERSION == v"1.5",
+         doctest = false,
          sitename = "Flux",
          pages = ["Home" => "index.md",
                   "Building Models" =>

docs/src/utilities.md

Lines changed: 31 additions & 10 deletions
@@ -3,37 +3,58 @@
 Flux provides utility functions which can be used to initialize your layers
 or to regularly execute callback functions.
 
-## Layer Initialization
+## Layer Initialisation
 
-These are primarily useful if you are planning to write your own layers.
-Flux initializes convolutional layers and recurrent cells with `glorot_uniform`
-by default.
-To change the default on an applicable layer, pass the desired function with the
-`init` keyword. For example:
+Flux initialises convolutional layers and recurrent cells with `glorot_uniform` by default.
+Most layers accept a function as an `init` keyword, which replaces this default. For example:
 
 ```jldoctest; setup = :(using Flux)
-julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
-Conv((3, 3), 1 => 8, relu)  # 80 parameters
+julia> conv = Conv((3, 3), 3 => 2, relu; init=Flux.glorot_normal)
+Conv((3, 3), 3 => 2, relu)  # 56 parameters
+
+julia> conv.bias
+2-element Vector{Float32}:
+ 0.0
+ 0.0
+```
+
+Note that `init` creates the weight array, but not the bias vector.
+
+Many of the initialisation functions accept keywords such as `gain`,
+and a random number generator. To make it easy to pass these to layers,
+there are methods which return a function:
+
+```jldoctest; setup = :(using Flux, Random)
+julia> Dense(4 => 5, tanh; init=Flux.glorot_uniform(gain=2))
+Dense(4 => 5, tanh)  # 25 parameters
+
+julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
+Dense(4 => 5, tanh)  # 25 parameters
 ```
 
 ```@docs
 Flux.glorot_uniform
 Flux.glorot_normal
 Flux.kaiming_uniform
 Flux.kaiming_normal
+Flux.truncated_normal
 Flux.orthogonal
 Flux.sparse_init
+Flux.identity_init
+Flux.ones32
+Flux.rand32
 ```
 
 ## Changing the type of model parameters
 
+The default `eltype` for models is `Float32` since models are often trained/run on GPUs.
+The `eltype` of model `m` can be changed to `Float64` by `f64(m)`:
+
 ```@docs
 Flux.f64
 Flux.f32
 ```
 
-The default `eltype` for models is `Float32` since models are often trained/run on GPUs. The `eltype` of model `m` can be changed to `Float64` by `f64(m)`, or to `Float32` by `f32(m)`.
-
 ## Model Building
 
 Flux provides some utility functions to help you generate models in an automated fashion.

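The new doc text above says that initialisers called with only keyword arguments (or an RNG) return a function. A minimal sketch of that mechanism, written for illustration rather than taken from the diff; the dimensions here are arbitrary:

```julia
using Flux

# Calling glorot_uniform with only a keyword returns a function of the array
# dimensions; layers later call it to build their weight arrays.
init = Flux.glorot_uniform(gain = 2)

W = init(5, 4)      # 5×4 Float32 matrix, as a Conv/Dense layer would create internally
size(W), eltype(W)  # ((5, 4), Float32)
```
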
perf/bench_utils.jl

Lines changed: 5 additions & 4 deletions
@@ -1,13 +1,14 @@
 using BenchmarkTools
 using Flux
 using CUDA
-using Zygote: pullback
+using Zygote: pullback, ignore
 
 
 fw(m, x) = m(x)
 bw(back) = back(1f0)
-fwbw(m, ps, x) = gradient(() -> sum(m(x)), ps)
-
+fwbw(m, ps, x) = gradient(() -> sum(fw(m, x)), ps)
+pb(m, ps, x) = pullback(() -> sum(fw(m, x)), ps)
+
 function run_benchmark(model, x; cuda=true)
 
   if cuda
@@ -16,7 +17,7 @@ function run_benchmark(model, x; cuda=true)
   end
 
   ps = Flux.params(model)
-  y, back = pullback(() -> sum(model(x)), ps)
+  y, back = pb(model, ps, x)
 
 
   if cuda

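For readers unfamiliar with the `pullback` pattern that `pb` wraps above, here is a self-contained sketch using a toy `Dense` model (my own example, not the benchmark's actual workloads): the forward call returns the loss plus a closure, and seeding that closure with `1f0`, exactly as `bw` does, yields the gradients.

```julia
using Flux
using Zygote: pullback

m  = Dense(10 => 5)
x  = randn(Float32, 10, 8)
ps = Flux.params(m)

# Forward pass reduced to a scalar, plus the closure for the reverse pass.
loss, back = pullback(() -> sum(m(x)), ps)

gs = back(1f0)   # same seeding as bw(back) in the benchmark helpers
gs[m.weight]     # gradient array, same shape as m.weight
```
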
perf/recurrent.jl

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+
+
+struct RNNWrapper{T}
+  rnn::T
+end
+Flux.@functor RNNWrapper
+
+# Need to specialize for RNNWrapper.
+fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin
+  Flux.reset!(r.rnn)
+  [r.rnn(x) for x in X]
+end
+
+fw(r::RNNWrapper, X) = begin
+  Flux.reset!(r.rnn)
+  r.rnn(X)
+end
+
+fwbw(r::RNNWrapper, ps, X::Vector{<:AbstractArray}) = gradient(ps) do
+  y = fw(r, X)
+  sum(sum(y))
+end
+
+pb(r::RNNWrapper, ps, X::Vector{<:AbstractArray}) = pullback(ps) do
+  y = fw(r, X)
+  sum(sum(y))
+end
+
+function rnn_benchmark_sweep(data_creator::Function, rnn_type)
+  for n in [2, 20, 200, 1000], ts in [1, 4, 16, 64]
+    x, x_n = data_creator(n, ts)
+    model = RNNWrapper(rnn_type(n, n))
+
+    println("$rnn_type $x_n CPU n=$n, ts=$ts")
+    run_benchmark(model, x, cuda=false)
+
+    println("$rnn_type $x_n CUDA n=$n, ts=$ts")
+    try
+      run_benchmark(model, x, cuda=true)
+    catch ex
+      @show typeof(ex)
+      if ex isa OutOfGPUMemoryError
+        @warn "Not enough GPU memory to run test"
+      else
+        rethrow(ex)
+      end
+    end
+  end
+end
+
+for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
+  rnn_benchmark_sweep(rnn_type) do n, ts
+    [randn(Float32, n, n) for _ in 1:ts], "Vec"
+  end
+end
+
+for rnn_type in [Flux.RNN, Flux.GRU, Flux.LSTM]
+  rnn_benchmark_sweep(rnn_type) do n, ts
+    randn(Float32, n, n, ts), "Block"
+  end
+end
+

perf/runbenchmarks.jl

Lines changed: 3 additions & 0 deletions
@@ -11,3 +11,6 @@ include("conv.jl")
 
 @info "Benchmark VGG"
 include("vgg.jl")
+
+@info "Benchmark Recurrent"
+include("recurrent.jl")

src/functor.jl

Lines changed: 4 additions & 2 deletions
@@ -213,14 +213,16 @@ paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
 """
     f32(m)
 
-Convert the `eltype` of model's parameters to `Float32`.
+Converts the `eltype` of model's parameters to `Float32` (which is Flux's default).
+Recurses into structs marked with [`@functor`](@ref).
 """
 f32(m) = paramtype(Float32, m)
 
 """
     f64(m)
 
-Convert the `eltype` of model's parameters to `Float64`.
+Converts the `eltype` of model's parameters to `Float64`.
+Recurses into structs marked with [`@functor`](@ref).
 """
 f64(m) = paramtype(Float64, m)
 

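As a quick illustration of the updated docstrings (my own sketch, not part of the diff), `f64` and `f32` recurse through `@functor`-marked structs such as `Chain`, converting every parameter array they reach:

```julia
using Flux

m = Chain(Dense(2 => 3, relu), Dense(3 => 1))
eltype(m[1].weight)    # Float32, Flux's default

m64 = f64(m)           # recurses into the Chain and both Dense layers
eltype(m64[1].weight)  # Float64

m32 = f32(m64)         # converts back to the Float32 default
eltype(m32[1].weight)  # Float32
```
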
src/layers/normalise.jl

Lines changed: 6 additions & 7 deletions
@@ -164,22 +164,21 @@ struct LayerNorm{F,D,T,N}
   affine::Bool
 end
 
-function LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
-  sz = sz isa Integer ? (sz,) : sz
-  diag = affine ? Diagonal(sz...) : nothing
-  return LayerNorm(λ, diag, ϵ, sz, affine)
+function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
+  diag = affine ? Diagonal(sz...) : identity
+  return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
 end
 
 @functor LayerNorm
 
 function (a::LayerNorm)(x)
-  x = normalise(x, dims=1:length(a.size), ϵ=a.ϵ)
-  a.diag === nothing ? a.λ.(x) : a.λ.(a.diag(x))
+  x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+  return a.λ === identity ? x : a.λ.(x)
 end
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")
-  l.λ == identity || print(io, ", $(l.λ)")
+  l.λ === identity || print(io, ", ", l.λ)
   hasaffine(l) || print(io, ", affine=false")
   print(io, ")")
 end

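A small usage sketch of the rewritten forward pass (my addition, assuming the post-change behaviour): `normalise` runs first, then the stored `diag`, which starts out as an identity-like scale-and-shift, then the activation.

```julia
using Flux, Statistics

ln = LayerNorm(5)           # affine=true by default; scale starts at 1, bias at 0
x  = randn(Float32, 5, 16)  # 16 samples with 5 features each

y = ln(x)

# At initialisation the affine part changes nothing, so each column comes out
# with roughly zero mean and unit standard deviation.
mean(y, dims = 1)                    # ≈ 0 for every column
std(y, dims = 1, corrected = false)  # ≈ 1 for every column
```
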
src/layers/stateless.jl

Lines changed: 3 additions & 4 deletions
@@ -33,9 +33,8 @@ Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given b
 Per default, `dims` is the last dimension.
 `ϵ` is a small additive factor added to the denominator for numerical stability.
 """
-function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
+@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
   μ = mean(x, dims=dims)
-  # σ = std(x, dims=dims, mean=μ, corrected=false) # use this when Zygote#478 gets merged
-  σ = std(x, dims=dims, corrected=false)
-  return (x .- μ) ./ (σ .+ ϵ)
+  σ = std(x, dims=dims, mean=μ, corrected=false)
+  return @. (x - μ) / (σ + ϵ)
 end

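To make the reconstructed return line concrete, here is a standalone sketch (mirroring the formula rather than calling Flux's internal `normalise`) of the computation over the chosen `dims`:

```julia
using Statistics

x = randn(Float32, 4, 10)
ϵ = 1f-5
dims = 2    # the last dimension, matching the function's default

μ = mean(x, dims = dims)
σ = std(x, dims = dims, mean = μ, corrected = false)

y = @. (x - μ) / (σ + ϵ)    # the same expression as the new return line

mean(y, dims = dims)                    # ≈ 0
std(y, dims = dims, corrected = false)  # ≈ 1
```
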
src/optimise/optimisers.jl

Lines changed: 3 additions & 3 deletions
@@ -51,7 +51,7 @@ Gradient descent optimizer with learning rate `η` and momentum `ρ`.
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
 - Momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect dampening oscillations.
+  prominent direction, in effect damping oscillations.
 
 # Examples
 ```julia
@@ -84,7 +84,7 @@ Gradient descent optimizer with learning rate `η` and Nesterov momentum `ρ`.
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
 - Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect dampening oscillations.
+  prominent direction, in effect damping oscillations.
 
 # Examples
 ```julia
@@ -121,7 +121,7 @@ generally don't need tuning.
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
 - Momentum (`ρ`): Controls the acceleration of gradient descent in the
-  prominent direction, in effect dampening oscillations.
+  prominent direction, in effect damping oscillations.
 
 # Examples
 ```julia

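On the "damping oscillations" wording fixed above: a minimal sketch of classical momentum (written from the docstring's description of `η` and `ρ`, not copied from Flux's optimiser code) showing why an oscillating gradient component is damped while a consistent one is accelerated.

```julia
function momentum_demo(grads; η = 0.1, ρ = 0.9)
    v, w = 0.0, 0.0
    for g in grads
        v = ρ * v - η * g   # velocity: old gradients decay with factor ρ
        w += v              # parameter step follows the velocity
    end
    return w
end

# A gradient that flips sign every step mostly cancels inside v, so the weight
# barely moves; a constant gradient of the same size builds v up and moves it far.
momentum_demo([(-1.0)^t for t in 1:20])   # ≈ 0: oscillation damped
momentum_demo(ones(20))                   # large negative: steady direction accelerated
```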