Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ function (a::Dense)(x::AbstractVecOrMat)
return σ.(a.weight * x .+ a.bias)
end

# Fast path for a Dense layer whose activation is `identity` and whose bias
# field is a `Bool`: only the matrix product is computed, with no broadcast.
# NOTE(review): this assumes a Bool bias is always `false` (i.e. "no bias");
# confirm that no constructor can store `true` here.
function (a::Dense{typeof(identity), <:AbstractMatrix, Bool})(x::AbstractVecOrMat)
    return a.weight * x
end

# Apply a Dense layer to an array of any rank: collapse every dimension after
# the first into one, call the layer on the resulting matrix, then restore the
# trailing dimensions of the input on the output.
function (a::Dense)(x::AbstractArray)
    flattened = reshape(x, size(x,1), :)
    result = a(flattened)
    trailing = size(x)[2:end]
    return reshape(result, :, trailing...)
end

Expand Down Expand Up @@ -246,6 +249,9 @@ function (a::Scale)(x::AbstractArray)
σ.(a.scale .* x .+ a.bias)
end

# Fast path for a Scale layer whose activation is `identity` and whose bias
# field is a `Bool`: only the element-wise scaling is performed.
# NOTE(review): this assumes a Bool bias is always `false` (i.e. "no bias");
# confirm against the constructors.
function (a::Scale{typeof(identity), <:AbstractArray, Bool})(x::AbstractArray)
    return a.scale .* x
end

function Base.show(io::IO, l::Scale)
print(io, "Scale(", join(size(l.scale), ", "))
l.σ == identity || print(io, ", ", l.σ)
Expand Down
12 changes: 12 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ function (c::Conv)(x::AbstractArray)
cdims = conv_dims(c, x)
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
# Fast path, no broadcast: for a Conv layer with `identity` activation and a
# `Bool` bias field, return the raw convolution of `x` with the layer weights.
(c::Conv{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) =
    conv(x, c.weight, conv_dims(c, x))

# Size of the second-to-last weight dimension, multiplied by the layer's
# `groups` field.
function _channels_in(l::Conv)
    w = l.weight
    return size(w, ndims(w) - 1) * l.groups
end
# Size of the last dimension of the layer's weight array.
function _channels_out(l::Conv)
    w = l.weight
    return size(w, ndims(w))
end
Expand Down Expand Up @@ -332,6 +336,10 @@ function (c::ConvTranspose)(x::AbstractArray)
cdims = conv_transpose_dims(c, x)
σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
# Fast path, no broadcast: for a ConvTranspose layer with `identity`
# activation and a `Bool` bias field, return `∇conv_data` applied to `x` and
# the layer weights directly.
(c::ConvTranspose{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) =
    ∇conv_data(x, c.weight, conv_transpose_dims(c, x))

function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
Expand Down Expand Up @@ -470,6 +478,10 @@ function (c::CrossCor)(x::AbstractArray)
cdims = crosscor_dims(c, x)
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
# Fast path, no broadcast: for a CrossCor layer with `identity` activation and
# a `Bool` bias field, return the raw cross-correlation of `x` with the layer
# weights.
(c::CrossCor{<:Any,<:Any,typeof(identity),<:AbstractArray,Bool})(x::AbstractArray) =
    crosscor(x, c.weight, crosscor_dims(c, x))

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
Expand Down
18 changes: 9 additions & 9 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ true
```
"""
struct LayerNorm{F,D,T,N}
λ::F
λ::F # this field is not used
diag::D
ϵ::T
size::NTuple{N,Int}
Expand Down Expand Up @@ -254,16 +254,16 @@ function _norm_layer_forward(
end
end

o = _norm_layer_forward(x, μ, σ², l.ϵ)
hasaffine(l) || return l.λ.(o)

γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* o .+ β)
s = (inv∘sqrt).(σ² .+ l.ϵ) # faster to un-fuse this, smaller... ideally mean_var(x, ε)?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's unfused by Zygote anyhow, might as well do that here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For just the forward pass, it was still faster to un-fuse this, to do inv & sqrt N times not N^3.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that what your comment is saying? I might be misunderstanding, does "un-fuse" here refer to extracting s as its own variable or to writing s = inv.(sqrt.(σ² .+ l.ϵ)) instead of s = (inv∘sqrt).(σ² .+ l.ϵ)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes maybe we are agreeing. The comment was meant to answer "why make s at all", since without it things got slower. (inv∘sqrt) is probably premature optimisation.

if hasaffine(l)
γ = reshape(l.γ, affine_shape) # ideally reshape on construction, store Scale?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with packing the affine params/activation in a Scale is that batchnorm functions in 3rd party backends (notably cuDNN) expect them to be passed in alongside all the other params. Thus the NNlib-level API has to be batchnorm(x, ..., γ, β), so the Scale only exists as a container to hold the affine params.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. We should probably still make these arrays the size required on construction, and make them even if they won't be used, instead of this:

https://github.com/FluxML/NNlibCUDA.jl/blob/master/src/cudnn/batchnorm.jl#L21

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Flux ever call that NNlib code?

Copy link
Member

@ToucheSir ToucheSir Dec 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the only CUDA.jl-reliant functionality left in this repo aside from the Functors stuff: https://github.com/FluxML/Flux.jl/blob/master/src/cuda/cudnn.jl. Absolute kludge as you can see, which is why these routines should be moved to NNlib sooner rather than later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, I forgot about that file. But I remember seeing it when trying to remove CUDA... agree that NNlib is the right place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a quick blame of the NNlibCUDA line above and came up with FluxML/NNlibCUDA.jl#36. I don't recall why the arrays are allocated instead of just set as CU_NULL before the call. The cuDNN docs don't mention that the bias and scale params can be null, so maybe that's why. If it turns out they can be and it's simply undocumented, we should revisit this.

β = reshape(l.β, affine_shape)
return l.λ.(γ .* s .* (x .- μ) .+ β)
else
return l.λ.(s .* (x .- μ))
end
end

# Normalise `x` element-wise: subtract the mean `μ`, then divide by the square
# root of the variance `σ²` plus the stabiliser `ϵ`. Everything is one fused
# broadcast, so scalar and array arguments may be mixed.
@inline function _norm_layer_forward(x, μ, σ², ϵ)
    return @. (x - μ) / sqrt(σ² + ϵ)
end

function _track_stats!(
bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
Expand Down