diff --git a/Project.toml b/Project.toml
index 5884014..0fee8e5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,10 +3,14 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
 version = "0.1.1"
 
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
diff --git a/README.md b/README.md
index f4d10da..d5a5842 100644
--- a/README.md
+++ b/README.md
@@ -37,3 +37,5 @@ As will any features which migrate to Flux itself.
 * More advanced [`train!` function](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/train.jl)
 * Macro for [making custom layers](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/compact.jl) quickly
 * Experimental [`apply(c::Chain, x)`](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/chain.jl) interface
+* [Pre-allocated](https://github.com/FluxML/Fluxperimental.jl/blob/master/src/preallocated.jl)
+working space for some layers
diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl
index 91438b0..ebe1cd7 100644
--- a/src/Fluxperimental.jl
+++ b/src/Fluxperimental.jl
@@ -8,9 +8,11 @@ export Split, Join
 include("train.jl")
 export shinkansen!
 
-
 include("chain.jl")
 
 include("compact.jl")
 
+include("preallocated.jl")
+export pre, nopre
+
 end # module Fluxperimental
diff --git a/src/preallocated.jl b/src/preallocated.jl
new file mode 100644
index 0000000..139520c
--- /dev/null
+++ b/src/preallocated.jl
@@ -0,0 +1,346 @@
+using Flux, ChainRulesCore
+using LinearAlgebra: mul!
+using FastBroadcast: @..
+using Strided
+
+const NoT = NoTangent()
+
+"""
+    PreLayer(Dense(2 => 3, relu))
+
+Stores, along with the layer, pre-allocated space for its output,
+and all gradient components. Only works on layers it understands.
+"""
+struct PreLayer{L,G,V}
+  layer::L
+  grad::G  # same fixed sizes as layer
+  fwd::V   # vector of dynamic length
+  rev::V
+end
+
+Flux.@functor PreLayer
+Flux.trainable(p::PreLayer) = (; layer = p.layer)
+
+"""
+    model |> pre
+
+Wrap as many layers as possible with `PreLayer`,
+to store pre-allocated space for output & gradient.
+Ignores layers it doesn't understand.
+"""
+pre(model) = fmap(PreLayer, model; exclude = x -> hasmethod(PreLayer, Tuple{typeof(x)}))
+
+"""
+    nopre(model)
+
+Remove all `PreLayer`s & return the plain model.
+"""
+nopre(model) = fmap(x -> x.layer, model; exclude = x -> x isa PreLayer)
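# --- Annotation (not part of the patch): usage sketch for `pre` / `nopre` ---
# A minimal round trip, assuming this branch is loaded as `using Flux, Fluxperimental`;
# the layer sizes below are arbitrary examples. `pre` wraps every layer it has a
# `PreLayer` constructor for, and `nopre` unwraps them again.

using Flux, Fluxperimental

model  = Chain(Dense(784 => 32, relu), Dense(32 => 10), softmax)
pmodel = model |> pre              # the Dense layers and softmax get wrapped in PreLayer
x = randn(Float32, 784, 16)

pmodel(x) ≈ model(x)               # same forward result, but buffers are reused
nopre(pmodel)                      # plain Chain again, without the wrappers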
+
+
+#####
+##### Dense
+#####
+
+function PreLayer(d::Dense)
+  grad = _struct_sim(d)
+  fwd, rev = similar(d.weight, 0), similar(d.weight, 0)
+  PreLayer(d, grad, fwd, rev)
+end
+
+function (p::PreLayer{<:Dense})(x::AbstractMatrix{<:Real})
+  y, dx = _pre_setup(p, x)
+  _densecall!(y, p, x, dx)
+end
+
+function _pre_setup(p::PreLayer{<:Dense}, x)  # this function @nograd
+  _, b = size(x)
+  o, i = size(p.layer.weight)
+  if o*b != length(p.fwd)
+    resize!(p.fwd, o*b)
+    resize!(p.rev, i*b)
+  end
+  y = _pre_reshape(p.fwd, (o,b))
+  dx = _pre_reshape(p.rev, (i,b))
+  (; y, dx)
+end
+
+function _densecall!(y, p, x, dx)
+  y .= p.layer.bias
+  mul!(y, p.layer.weight, x, true, true)
+  act!(y, p.layer.σ)
+  y
+end
+
+function ChainRulesCore.rrule(::typeof(_densecall!), y, p, x, dx)
+  y = _densecall!(y, p, x, dx)
+  function back(dy)
+    dy = unthunk(dy)
+    dy = ∇act!(y, dy, p.layer.σ)
+    # layer
+    weight = mul!(p.grad.weight, dy, x')
+    bias = ∇bias!(p.grad.bias, dy)
+    tang = Tangent{Dense}(; weight, bias)
+    # input
+    dx = mul!(dx, p.layer.weight', dy)
+    return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+  end
+  y, back
+end
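# --- Annotation (not part of the patch): what `_densecall!` computes ---
# The in-place Dense forward first broadcasts the bias into the output buffer, then
# uses 5-argument `mul!(C, A, B, α, β)`, which computes `α*A*B + β*C`, to accumulate
# the matrix product on top of it. A standalone sketch on plain arrays (the names
# `W`, `b`, `xs`, `y` are just examples):

using LinearAlgebra: mul!

W  = randn(Float32, 3, 5)
b  = randn(Float32, 3)
xs = randn(Float32, 5, 7)

y = similar(W, 3, 7)
y .= b                        # bias broadcast into the pre-allocated buffer
mul!(y, W, xs, true, true)    # y = 1*W*xs + 1*y, i.e. W*xs .+ b with no fresh allocation

y ≈ W * xs .+ b               # should hold up to floating-point roundoff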
+
+#####
+##### Scale
+#####
+
+scale!(y, (scale, ds), (x, dx), (bias, db)) = y .= scale .* x .+ bias
+# scale!(y, (scale, ds), (x, dx), (bias, db)) = @strided y .= scale .* x .+ bias
+
+function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias, db))
+  y = scale!(y, (scale, ds), (x, dx), (bias, db))
+  function back(dy)
+    dy = unthunk(dy)
+    @strided dx .= dy .* scale
+    @strided ds .= dy .* x
+    dbias = ∇bias!(db, dy)
+    return (NoT, NoT, (ds, NoT), (dx, NoT), (dbias, NoT))
+  end
+  y, back
+end
+
+#####
+##### Conv
+#####
+
+function PreLayer(c::Conv)
+  grad = _struct_sim(c)
+  fwd, rev = similar(c.weight, 0), similar(c.weight, 0)
+  PreLayer(c, grad, fwd, rev)
+end
+
+function (p::PreLayer{<:Conv})(x::AbstractArray{<:Real})
+  y, dx = _pre_setup(p, x)
+  _convcall!(y, p, x, dx)
+end
+
+using Flux: conv_dims, conv_reshape_bias
+using Flux.NNlib: fast_act, conv!, output_size, channels_out
+
+function _pre_setup(p::PreLayer{<:Conv}, x)
+  cdims = conv_dims(p.layer, x)
+  ysize = (output_size(cdims)..., channels_out(cdims), size(x)[end])
+  if prod(ysize) != length(p.fwd)
+    resize!(p.fwd, prod(ysize))
+    resize!(p.rev, length(x))
+  end
+  y = _pre_reshape(p.fwd, ysize)
+  dx = _pre_reshape(p.rev, size(x))
+  (; y, dx)
+end
+
+function _convcall!(y, p, x, dx)
+  cdims = conv_dims(p.layer, x)
+  conv!(y, x, p.layer.weight, cdims)
+  if p.layer.bias isa AbstractArray
+    y .+= conv_reshape_bias(p.layer)
+  end
+  act!(y, fast_act(p.layer.σ, x))
+end
+
+# function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+#   y = _densecall!(y, p, x, dx)
+#   function back(dy)
+#     dy = unthunk(dy)
+#     dy = ∇act!(y, dy, p.layer.σ)
+#     # layer
+#     weight = mul!(p.grad.weight, dy, x')
+#     bias = ∇bias!(p.grad.bias, dy)
+#     tang = Tangent{Dense}(; weight, bias)
+#     # input
+#     dx = mul!(dx, p.layer.weight', dy)
+#     return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+#   end
+#   y, back
+# end
+
+
+
+#####
+##### BatchNorm
+#####
+
+function PreLayer(bn::BatchNorm)
+  grad = (β = similar(bn.β), γ = similar(bn.γ))  # only trainable fields
+  fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
+  PreLayer(bn, grad, fwd, rev)
+end
+
+function (p::PreLayer{<:BatchNorm})(x::AbstractArray{<:Real})
+  y, dx = _pre_setup(p, x)
+  # _batchnormcall!(y, p, x, dx)
+
+  # from (BN::BatchNorm)(x)
+  N = ndims(x)
+  reduce_dims = [1:N-2; N]
+  affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+  _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
+end
+
+using Flux: _isactive, _track_stats!, hasaffine
+using Statistics: mean, var
+
+function _norm_layer_forward!(y, p, (x, dx); reduce_dims, affine_shape)
+  l = p.layer
+  N = ndims(x)
+
+  # This block verbatim from Flux. However, mean & var aren't in-place,
+  # nor are their gradients... add more storage?
+
+  if !_isactive(l) && l.track_stats  # testmode with tracked stats
+    stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+    μ = reshape(l.μ, stats_shape)
+    σ² = reshape(l.σ², stats_shape)
+  else  # trainmode or testmode without tracked stats
+    μ = mean(x; dims=reduce_dims)
+    σ² = var(x; mean=μ, dims=reduce_dims, corrected=false)
+    if l.track_stats
+      _track_stats!(l, x, μ, σ², reduce_dims)  # update moving mean/std
+    end
+  end
+
+  y = _norm_layer_forward!(y, x, dx, μ, σ², l.ϵ)
+  hasaffine(l) || return act!(y, l.λ)
+
+  γ = reshape(l.γ, affine_shape)
+  β = reshape(l.β, affine_shape)
+  # return l.λ.(γ .* y .+ β)
+  y2 = scale!(y, (γ, p.grad.γ), (y, dx), (β, p.grad.β))
+  return act!(y2, l.λ)
+end
+
+_norm_layer_forward!(y, x, dx, μ, σ², ϵ) = y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)
+# _norm_layer_forward!(y, x, dx, μ, σ², ϵ) = @strided y .= (x .- μ) ./ sqrt.(σ² .+ ϵ)
+
+function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ², ϵ)
+  y = _norm_layer_forward!(y, x, dx, μ, σ², ϵ)
+  function back(dy)
+    dx .= dy ./ sqrt.(σ² .+ ϵ)
+    # TODO write gradients for mean & variance, these are WRONG!
+    dμ = NoT
+    dσ² = NoT
+    return (NoT, NoT, dx, NoT, dμ, dσ², NoT)
+  end
+  y, back
+end
+
+#####
+##### softmax
+#####
+
+function PreLayer(::typeof(softmax))
+  fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal, demands `model |> pre |> gpu`
+  PreLayer(softmax, nothing, fwd, rev)
+end
+
+function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
+  y, dx = _pre_setup(p, x)  # generic version
+  _softmaxcall!(y, p, x, dx)
+end
+
+_softmaxcall!(y, p, x, dx) = softmax!(y, x)
+
+function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
+  y = _softmaxcall!(y, p, x, dx)
+  function back(dy)
+    # TODO: CHECK THIS!
+    dx .= dy .* y
+    dx .= dx .- y .* sum(dx; dims=1)  # could sum! into the end of rev
+    return (NoT, NoT, NoT, dx, NoT)  # last one could be NotImplemented?
+  end
+  y, back
+end
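# --- Annotation (not part of the patch): sanity check for the softmax pullback ---
# The rrule above is marked "TODO: CHECK THIS!". One way to check the in-place
# formula is to compare it with Zygote's gradient of the ordinary `softmax` under
# the same cotangent; the sizes here are arbitrary.

using Flux, Zygote

let x = randn(Float32, 5, 3), dy = randn(Float32, 5, 3)
  y  = softmax(x)
  dx = dy .* y
  dx = dx .- y .* sum(dx; dims=1)                    # formula used in the rrule above
  dx_ad = gradient(z -> sum(softmax(z) .* dy), x)[1]
  dx ≈ dx_ad                                         # true if the pullback is correct
end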
+
+
+#####
+##### activation functions
+#####
+
+act!(y, ::typeof(identity)) = y
+function act!(y, act::F) where F
+  σ = Flux.NNlib.fast_act(act, y)
+  # y .= σ.(y)
+  # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
+  # maybe you could patch Strided.jl to avoid it? Or use another package...
+  # @strided y .= σ.(y)
+  @.. y = σ(y)
+end
+
+# Piracy, disable @strided on CuArrays:
+Strided.maybestrided(x::Flux.CuArray) = x
+
+# For this rule, it's important to use what `act!` returns, not what it mutates
+ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dy -> (NoT, ∇act!(y, dy, f), NoT)
+
+∇act!(y, dy, ::typeof(identity)) = dy
+∇act!(y, dy, ::typeof(relu)) = @.. y = ifelse(y>0, dy, 0f0)
+∇act!(y, dy, ::typeof(tanh)) = @.. y = dy * (1 - y^2)
+∇act!(y, dy, ::typeof(sigmoid)) = @.. y = dy * y * (1 - y)
+
+
+function PreLayer(::typeof(relu))
+  fwd, rev = zeros(Float32, 0), zeros(Float32, 0)  # not ideal
+  PreLayer(relu, nothing, fwd, rev)
+end
+
+function (p::PreLayer{typeof(relu)})(x::AbstractArray{<:Real})
+  y, dx = _pre_setup(p, x)  # generic version
+  _relucall!(y, p, x, dx)
+end
+
+_relucall!(y, p, x, dx) = y .= relu.(x)
+
+function ChainRulesCore.rrule(::typeof(_relucall!), y, p, x, dx)
+  y = _relucall!(y, p, x, dx)
+  function back(dy)
+    @. dx = ifelse(y>0, dy, 0f0)
+    return (NoT, NoT, NoT, dx, NoT)
+  end
+  y, back
+end
+
+#####
+##### PreLayer utils
+#####
+
+_struct_sim(x) = Flux.fmapstructure(x) do x
+  x isa AbstractArray{<:Real} ? similar(x) : nothing
+end
+
+function _pre_setup(p::PreLayer, x)  # generic version
+  if length(x) != length(p.fwd)
+    resize!(p.fwd, length(x))
+    resize!(p.rev, length(x))
+  end
+  y = _pre_reshape(p.fwd, size(x))
+  dx = _pre_reshape(p.rev, size(x))
+  (; y, dx)
+end
+ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)
+
+# Cannot use reshape(::Array), as that prevents later resize!
+_pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
+# _pre_reshape(x::Array, size::Tuple) = Base.__reshape((x, Base.IndexStyle(x)), size)  # what Base does, no better
+# Must use reshape(::CuArray) as mul! rejects ReshapedArray
+_pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
+_pre_reshape(x, size::Tuple) = reshape(x, size)
+
+# Base piracy! to prevent ReshapedArray from going missing
+Base._reshape(R::Base.ReshapedArray, dims::Base.Dims) = Base.ReshapedArray(R.parent, dims, ())
+
+∇bias!(::Bool, dx) = NoT
+∇bias!(bias, dx) = sum!(bias, dx)
+
+function Base.show(io::IO, p::PreLayer)
+  show(io, p.layer)
+  printstyled(io, " |> pre", color=:blue)
+end
+
+Flux._show_children(p::PreLayer) = Flux._show_children(p.layer)
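# --- Annotation (not part of the patch): why `_pre_reshape` avoids `reshape(::Array)` ---
# Calling `reshape` on an `Array` marks its buffer as shared, after which `resize!`
# on the parent vector refuses to grow it; wrapping in `Base.ReshapedArray` skips
# that flag, which is what lets the `fwd`/`rev` buffers grow when a larger batch
# arrives. A rough demonstration (behaviour as observed on the Julia versions this
# patch targets; newer releases may differ):

buf1 = zeros(Float32, 6)
mat1 = reshape(buf1, 2, 3)            # shares buf1's memory and sets the shared-data flag
try
  resize!(buf1, 12)                   # typically errors: "cannot resize array with shared data"
catch err
  err
end

buf2 = zeros(Float32, 6)
mat2 = Base.ReshapedArray(buf2, (2, 3), ())   # plain wrapper, no shared-data flag
resize!(buf2, 12)                     # still allowed; `_pre_setup` rebuilds the wrapper next call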
diff --git a/test/preallocated.jl b/test/preallocated.jl
new file mode 100644
index 0000000..1f01349
--- /dev/null
+++ b/test/preallocated.jl
@@ -0,0 +1,138 @@
+
+m1 = Chain(Dense(784 => 32, relu), Dense(32 => 10), softmax)
+m2 = m1 |> pre
+
+x = randn(Float32, 784, 64);
+
+@test m1(x) ≈ m2(x)
+
+g1 = gradient((m,x) -> m(x)[1], m1, x)
+g2 = gradient((m,x) -> m(x)[1], m2, x)
+
+@test g1[1].layers[1].bias ≈ g2[1].layers[1].layer.bias
+@test g1[2] ≈ g2[2]
+
+
+#=
+
+julia> @btime gradient((m,x) -> m(x)[1], $m1, $x);
+  min 50.167 μs, mean 88.796 μs (58 allocations, 355.41 KiB)
+
+julia> @btime gradient((m,x) -> m(x)[1], $m2, $x);
+  min 57.792 μs, mean 66.050 μs (115 allocations, 17.75 KiB)
+
+
+
+let data = [(x,) for _ in 1:1000]
+  o1 = Flux.setup(Adam(), m1)
+  @btime Flux.train!((m,x) -> m(x)[1], $m1, $data, $o1)
+
+  o2 = Flux.setup(Adam(), m2)
+  @btime Flux.train!((m,x) -> m(x)[1], $m2, $data, $o2)
+
+  nothing
+end
+
+# Yesterday:
+# min 1.799 s, mean 1.802 s (177001 allocations, 352.94 MiB)
+# min 146.713 ms, mean 251.041 ms (295001 allocations, 25.71 MiB)
+
+# Today, wtf? Maybe threading changes have hurt.
+# min 244.235 ms, mean 251.582 ms (177001 allocations, 352.94 MiB)
+# min 224.760 ms, mean 227.594 ms (301001 allocations, 26.02 MiB)
+
+
+m1cu = m1 |> gpu
+m2cu = m2 |> gpu
+xcu = x |> gpu
+
+
+let data = [(xcu,) for _ in 1:1000]
+  o1 = Flux.setup(Adam(), m1cu)
+  CUDA.@time Flux.train!((m,x) -> sum(m(x)), m1cu, data, o1)
+
+  o2 = Flux.setup(Adam(), m2cu)
+  CUDA.@time Flux.train!((m,x) -> sum(m(x)), m2cu, data, o2)
+
+  nothing
+end
+# 1.280640 seconds (1.86 M CPU allocations: 111.723 MiB, 10.99% gc time) (17.00 k GPU allocations: 340.008 MiB, 8.80% memmgmt time)
+# 1.327849 seconds (1.73 M CPU allocations: 112.376 MiB, 6.70% gc time) (3.00 k GPU allocations: 2.689 MiB, 2.29% memmgmt time)
+
+
+=#
+
+
+m3 = Chain(Dense(784 => 1024, tanh), BatchNorm(1024), Dense(1024 => 10), softmax)
+m4 = m3 |> pre
+
+x = randn(Float32, 784, 64);
+
+@test m3(x) ≈ m4(x)
+
+@btime $m3($x);
+@btime $m4($x);
+
+#=
+
+julia> @btime $m3($x);
+  min 318.000 μs, mean 7.944 ms (31 allocations, 1.01 MiB)
+
+julia> @btime $m4($x);
+  min 410.459 μs, mean 440.106 μs (57 allocations, 3.55 KiB)
+
+=#
+
+x4 = randn(Float32, 28, 28, 1, 13);
+
+m5 = @autosize (size(x4)...,) Chain(
+  Conv((3,3), 1 => 7, relu, stride=2, pad=1),
+  Conv((3,3), _ => 9, relu, stride=2),
+  Conv((3,3), _ => 5, tanh, stride=2, bias=false),
+  Flux.flatten,
+  Dense(_ => 10),
+  )
+m6 = m5 |> pre
+
+@test m5(x4) ≈ m6(x4)
+
+#=
+
+julia> @btime $m5($x4);
+  min 139.125 μs, mean 191.653 μs (179 allocations, 262.73 KiB)
+
+julia> @btime $m6($x4);
+  min 140.125 μs, mean 196.337 μs (160 allocations, 86.39 KiB)
+
+=#
+
+
+using Metalhead
+m50 = Metalhead.ResNet(50)  # 100MB
+m50pre = m50 |> pre  # 200MB
+
+#=
+
+# First run
+
+julia> @time m50(randn(Float32, 100,100,3,32)) |> size
+  5.543590 seconds (6.11 M allocations: 1.963 GiB, 14.14% gc time, 96.22% compilation time)
+(1000, 32)
+
+julia> @time m50pre(randn(Float32, 100,100,3,32)) |> size
+ 16.098089 seconds (15.84 M allocations: 2.576 GiB, 62.26% gc time, 69.06% compilation time)
+(1000, 32)
+
+# Later
+
+julia> @time m50(randn(Float32, 100,100,3,32)) |> size
+ 11.541100 seconds (4.40 k allocations: 1.570 GiB, 85.73% gc time)
+(1000, 32)
+
+julia> @time m50pre(randn(Float32, 100,100,3,32)) |> size
+  4.664626 seconds (4.09 k allocations: 381.454 MiB, 61.15% gc time)
+(1000, 32)
+
+
+m50pre  # now 1.340 GiB
+
+=#
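# --- Annotation (not part of the patch): measuring the retained workspace ---
# The "# now 1.340 GiB" comment above reflects that a wrapped model keeps its
# `fwd`/`rev` buffers alive between calls, so it grows after the first forward pass.
# One way to quantify this is `Base.summarysize`; the model and batch size below are
# arbitrary examples, and the numbers will vary.

using Flux, Fluxperimental

m  = Chain(Dense(784 => 1024, relu), Dense(1024 => 10))
mp = m |> pre
xb = randn(Float32, 784, 256)

before = Base.summarysize(mp)
mp(xb)                               # first call resizes the internal buffers
after  = Base.summarysize(mp)
(before, after)                      # `after - before` ≈ retained workspace, in bytes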