
Commit 68b165c

Add f16 (#2184)
* add f16
* rm duplicate ones32
* add ones16, rand16, etc
* better docstrings
* make _match_eltype noisy about 32 -> 16
* add a few tests
* add f16 to docs
* fixes
* more tests
* news
* also remove some adapt piracy
* fixes & tests
* Revert "add ones16, rand16, etc" (this reverts commit 7d2e8f1)
* rm a test
* fixup
* fix promotion in BatchNorm
1 parent 2d357ad commit 68b165c

File tree: 13 files changed, +178 -50 lines


NEWS.md

Lines changed: 3 additions & 0 deletions

````diff
@@ -1,5 +1,8 @@
 # Flux Release Notes
 
+## v0.13.13
+* Added `f16` which changes precision to `Float16`, recursively.
+
 ## v0.13.12
 * CUDA.jl 4.0 compatibility.
 
````
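
As a quick orientation for the entry above, here is a minimal sketch of how the new `f16` is meant to be used; the model and sizes are arbitrary, chosen only for illustration:

```julia
using Flux

m = Dense(3 => 2)              # Flux's default: Float32 parameters
m16 = f16(m)                   # recursively converts every parameter to Float16

eltype(m16.weight)             # Float16
eltype(m16(rand(Float16, 3)))  # Float16 output for Float16 input
```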

docs/src/utilities.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -61,4 +61,5 @@ The `eltype` of model `m` can be changed to `Float64` by `f64(m)`:
 ```@docs
 Flux.f64
 Flux.f32
+Flux.f16
 ```
````

src/Flux.jl

Lines changed: 1 addition & 1 deletion

````diff
@@ -24,7 +24,7 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
   AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
   Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
   Upsample, PixelShuffle,
-  fmap, cpu, gpu, f32, f64, rand32, randn32, zeros32, ones32,
+  fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
   testmode!, trainmode!
 
 include("optimise/Optimise.jl")
````

src/deprecations.jl

Lines changed: 2 additions & 0 deletions

````diff
@@ -84,6 +84,8 @@ Base.@deprecate_binding ADADelta AdaDelta
 # Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working
 Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader"
 
+@deprecate paramtype(T,m) _paramtype(T,m) false  # internal method, renamed to make this clear
+
 @deprecate rng_from_array() default_rng_value()
 
 function istraining()
````
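
For context, `@deprecate old new false` keeps the old name callable but does not export it: calls emit a deprecation warning (when depwarns are enabled) and forward to the new name. A hypothetical session, assuming some model `m`:

```julia
using Flux

m = Dense(2 => 2)

Flux.paramtype(Float32, m)   # still works, but warns and forwards to the renamed helper
Flux._paramtype(Float32, m)  # the internal method that f32/f64/f16 now call
```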

src/functor.jl

Lines changed: 67 additions & 27 deletions

````diff
@@ -146,24 +146,26 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
 """
     cpu(m)
 
-Moves `m` onto the CPU, the opposite of [`gpu`](@ref).
+Copies `m` onto the CPU, the opposite of [`gpu`](@ref).
 Recurses into structs marked [`@functor`](@ref).
 
+# Example
 ```julia-repl
-julia> m = Dense(1,2)
-Dense(1, 2)
+julia> m_gpu = Dense(CUDA.randn(2, 5))
+Dense(5 => 2)  # 12 parameters
 
-julia> m_gpu = gpu(m)
-Dense(1, 2)
+julia> m_gpu.bias  # matches the given weight matrix
+2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
+ 0.0
+ 0.0
 
-julia> typeof(m_gpu.W)
-CuArray{Float32, 2}
+julia> m = m_gpu |> cpu
+Dense(5 => 2)  # 12 parameters
 
-julia> m_cpu = cpu(m_gpu)
-Dense(1, 2)
-
-julia> typeof(m_cpu.W)
-Matrix{Float32}
+julia> m.bias
+2-element Vector{Float32}:
+ 0.0
+ 0.0
 ```
 """
 cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x, exclude = _isleaf)
@@ -178,24 +180,32 @@ _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
 """
     gpu(x)
 
-Moves `m` to the current GPU device, if available. It is a no-op otherwise.
+Copies `m` to the current GPU device, if one is available.
+If no GPU is available, it does nothing (but prints a warning the first time).
+
+On arrays, this calls CUDA's `cu`, which also changes arrays
+with Float64 elements to Float32 while copying them to the device.
+To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref).
+
+Use [`cpu`](@ref) to copy back to ordinary `Array`s.
+See also [`f32`](@ref) and [`f16`](@ref) to change element type only.
+
 See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
 to help identify the current device.
 
-This works for functions, and any struct marked with [`@functor`](@ref).
-
+# Example
 ```julia-repl
-julia> m = Dense(1,2)
-Dense(1, 2)
+julia> m = Dense(rand(2, 3))  # constructed with Float64 weight matrix
+Dense(3 => 2)  # 8 parameters
 
-julia> typeof(m.W)
-Matrix{Float32}
+julia> typeof(m.weight)
+Matrix{Float64} (alias for Array{Float64, 2})
 
-julia> m_gpu = gpu(m)
-Dense(1, 2)
+julia> m_gpu = gpu(m)  # can equivalently be written m_gpu = m |> gpu
+Dense(3 => 2)  # 8 parameters
 
-julia> typeof(m_gpu.W)  # notice the type of the array changed to a CuArray
-CuArray{Float32, 2}
+julia> typeof(m_gpu.weight)
+CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
 ```
 """
 function gpu(x)
@@ -216,25 +226,55 @@ ChainRulesCore.@non_differentiable check_use_cuda()
 
 # Precision
 
-adapt_storage(T::Type{<:Real}, xs::AbstractArray{<:Real}) = convert.(T, xs) # piracy
+struct FluxEltypeAdaptor{T} end
+
+Adapt.adapt_storage(::FluxEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x)
 
-paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
+_paramtype(::Type{T}, m) where T = fmap(adapt(FluxEltypeAdaptor{T}()), m)
+_paramtype(::Type{T}, x::AbstractArray{<:Real}) where {T} = convert(AbstractArray{T}, x)
 
 """
     f32(m)
 
 Converts the `eltype` of model's parameters to `Float32` (which is Flux's default).
 Recurses into structs marked with [`@functor`](@ref).
+See also [`f64`](@ref) and [`f16`](@ref).
 """
-f32(m) = paramtype(Float32, m)
+f32(m) = _paramtype(Float32, m)
 
 """
     f64(m)
 
 Converts the `eltype` of model's parameters to `Float64`.
 Recurses into structs marked with [`@functor`](@ref).
"""
-f64(m) = paramtype(Float64, m)
+f64(m) = _paramtype(Float64, m)
+
+"""
+    f16(m)
+
+Converts the `eltype` of model's parameters to `Float16`.
+Recurses into structs marked with [`@functor`](@ref).
+
+Support for `Float16` is limited on many CPUs. Julia may
+convert to `Float32` for each operation, which is slow.
+
+# Example
+```jldoctest
+julia> m = Chain(Dense(784, 2048, relu), Dense(2048, 10))  # all Float32
+Chain(
+  Dense(784 => 2048, relu),  # 1_607_680 parameters
+  Dense(2048 => 10),         # 20_490 parameters
+)  # Total: 4 arrays, 1_628_170 parameters, 6.211 MiB.
+
+julia> m |> f16  # takes half the memory
+Chain(
+  Dense(784 => 2048, relu),  # 1_607_680 parameters
+  Dense(2048 => 10),         # 20_490 parameters
+)  # Total: 4 arrays, 1_628_170 parameters, 3.106 MiB.
+```
+"""
+f16(m) = _paramtype(Float16, m)
 
 # Functors for certain Julia data structures
 @functor Cholesky
````
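
The adaptor-based rewrite above is what lets `f16`, `f32` and `f64` walk arbitrary models: `fmap` recurses through anything marked with `@functor`, and `adapt(FluxEltypeAdaptor{T}())` converts each numeric array it reaches. A small sketch under those assumptions; the `Autoencoder` container is hypothetical, used only to show the recursion:

```julia
using Flux

# A custom container, marked with @functor so that fmap (and hence f16) can see inside it.
struct Autoencoder{E,D}
    encoder::E
    decoder::D
end
Flux.@functor Autoencoder

model = Autoencoder(Dense(8 => 4, relu), Dense(4 => 8))

model16 = f16(model)                 # same structure, half-precision arrays throughout
eltype(model16.encoder.weight)       # Float16
eltype(model16.decoder.bias)         # Float16

eltype(f32(model16).encoder.weight)  # Float32; each call simply re-adapts every array
```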

src/layers/normalise.jl

Lines changed: 4 additions & 2 deletions

````diff
@@ -194,7 +194,8 @@ function (a::LayerNorm)(x::AbstractArray)
       _size_check(a, x, d => size(a.diag.scale, d))
     end
   end
-  a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+  eps = convert(float(eltype(x)), a.ϵ)  # avoids promotion for Float16 data, but should ε change too?
+  a.diag(normalise(x, dims=1:length(a.size), ϵ=eps))
 end
 
 function Base.show(io::IO, l::LayerNorm)
@@ -223,7 +224,8 @@ function _norm_layer_forward(
     end
   end
 
-  o = _norm_layer_forward(x, μ, σ², l.ϵ)
+  eps = convert(float(T), l.ϵ)
+  o = _norm_layer_forward(x, μ, σ², eps)
   hasaffine(l) || return l.λ.(o)
 
   γ = reshape(l.γ, affine_shape)
````
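
The point of converting `ϵ` in both places is promotion: the layer stores `ϵ` as a wider float, and mixing a `Float32` or `Float64` scalar into `Float16` arithmetic promotes the whole result, which would silently undo `f16`. A standalone illustration of Julia's promotion rule (not Flux code):

```julia
x = Float16[1.0, 2.0, 3.0]

eltype(x .+ 1f-5)           # Float32: the Float32 epsilon promotes the result
eltype(x .+ Float16(1f-5))  # Float16: converting the epsilon first keeps half precision
```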

src/layers/stateless.jl

Lines changed: 7 additions & 0 deletions

````diff
@@ -79,6 +79,13 @@ function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64})
   convert(AbstractArray{Float32}, x)
 end
 
+# Bug in Float16 use?
+function _match_eltype(layer, ::Type{Float16}, x::AbstractArray{Float32})
+  @warn "Layer with Float16 parameters got Float32 input.
+    The input will be converted, but may indicate a problem in earlier layers." layer summary(x) maxlog=1
+  convert(AbstractArray{Float16}, x)
+end
+
 # Allow OneHot to reach specialisation of * etc:
 _match_eltype(layer, ::Type, x::OneHotLike) = x
 
````
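
A sketch of the behaviour this new method adds, assuming Flux at this commit: a layer whose parameters are `Float16` still accepts `Float32` input, but converts it and logs a one-time warning so the eltype mismatch is visible.

```julia
using Flux

m16 = f16(Dense(3 => 2))
x32 = rand(Float32, 3)

y = m16(x32)  # works: the input is converted by _match_eltype, with a one-time @warn
eltype(y)     # Float16, i.e. computation happens at the parameters' precision
```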

src/utils.jl

Lines changed: 0 additions & 3 deletions

````diff
@@ -468,9 +468,6 @@ identity_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (args...;k
 
 ChainRulesCore.@non_differentiable identity_init(::Any...)
 
-ones32(dims::Integer...) = Base.ones(Float32, dims...)
-zeros32(dims::Integer...) = Base.zeros(Float32, dims...)
-
 """
     ones32(size...) = ones(Float32, size...)
````

test/cuda/cudnn.jl

Lines changed: 10 additions & 6 deletions

````diff
@@ -2,24 +2,28 @@ using Flux, CUDA, Test
 using Flux: pullback
 
 @testset "CUDNN BatchNorm" begin
-  @testset "4D Input" begin
-    x = rand(Float32, 2, 2, 3, 4)
-    m = BatchNorm(3)
+  @testset "4D Input, $T" for (T,f) in [(Float32, identity), (Float16, f16)]
+    x = randn(T, 2, 2, 3, 4)
+    m = f(BatchNorm(3))
     gx = gpu(x)
     gm = gpu(m)
 
     y, back = pullback((m, x) -> m(x), m, x)
     gy, gback = pullback((m, x) -> m(x), gm, gx)
 
-    @test cpu(gy) ≈ y
+    @test cpu(gy) ≈ y  rtol=1e-3
+    @test eltype(gy) == T
+    @test eltype(gm(gx)) == T
 
-    Δ = randn(Float32, size(y))
+    Δ = randn(T, size(y))
     dm, dx = back(Δ)
-    gdm, gdx = gback(gpu(Δ))
+    gdm, gdx = gback(f(gpu(Δ)))
 
     @test dm[].γ ≈ cpu(gdm[].γ)
     @test dm[].β ≈ cpu(gdm[].β)
     @test dx ≈ cpu(gdx)
+    @test eltype(gdm[].γ) == T
+    @test eltype(gdx) == T
   end
 
   @testset "2D Input" begin
````

test/cuda/layers.jl

Lines changed: 48 additions & 0 deletions

````diff
@@ -290,3 +290,51 @@ end
     @test gpu(m).rng isa CUDA.RNG
   end
 end
+
+@testset "Misc. Float16" begin
+  # These tests are very far from exhaustive!
+
+  x = randn(Float16, 3, 4)
+  gx = gpu(x)
+
+  # Dense
+  m1 = f16(Dense(3 => 4, tanh))
+  gm1 = gpu(m1)
+
+  y1, back1 = Zygote.pullback(|>, x, m1)
+  gy1, gback1 = Zygote.pullback(|>, gx, gm1)
+
+  @test y1 ≈ m1(x) ≈ cpu(gy1)
+  @test eltype(y1) == eltype(m1(x)) == eltype(gy1) == Float16
+
+  @test back1(one.(y1))[2].weight ≈ cpu(gback1(one.(gy1))[2].weight)
+  @test eltype(gback1(one.(gy1))[2].bias) == Float16
+
+  # A fake loss with Float32
+  f1(x) = sum((Float32.(x) .- 1).^2)
+  @test gradient(f1, x)[1] ≈ cpu(gradient(f1, gx)[1])
+  @test eltype(gradient(f1, gx)[1]) == Float16
+
+  # Normalisation
+  m2 = Chain(LayerNorm(3), Dropout(0.1)) |> f16
+  gm2 = m2 |> gpu
+  @test m2(x) ≈ cpu(gm2(gx))
+  @test eltype(m2(x)) == Float16
+  @test eltype(gm2(gx)) == Float16
+
+  # Conv
+  x3 = randn(Float16, 7, 2, 1)
+  m3 = Conv((3,), 2=>1, sigmoid, pad=1, stride=2) |> f16
+  @test m3(x3) ≈ f16(f32(m3)(f32(x3))) ≈ cpu(gpu(m3)(gpu(x3)))
+  @test eltype(m3(x3)) == Float16
+  dw = gradient((m,x) -> sum(abs2, m(x)), m3, x3)[1].weight
+  @test dw ≈ f16(gradient((m,x) -> sum(abs2, m(x)), f32(m3), f32(x3))[1].weight)
+  @test dw ≈ cpu(gradient((m,x) -> sum(abs2, m(x)), gpu(m3), gpu(x3))[1].weight)
+  @test eltype(dw) == Float16
+
+  # Pooling
+  for pool in [MaxPool((2,)), MeanPool((2,))]
+    pool(reshape(x,3,4,1)) ≈ cpu(pool(reshape(gx,3,4,1)))
+    @test eltype(pool(reshape(gx,3,4,1))) == Float16
+  end
+end
````
