diff --git a/src/softmax.jl b/src/softmax.jl
index 182f2fb93..1527bbfaf 100644
--- a/src/softmax.jl
+++ b/src/softmax.jl
@@ -90,6 +90,170 @@ end
 
 fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))
 
+"""
+    fast_exp(x)
+
+For `x::Float32`, this is a much faster (about 20x)
+but much less accurate (about 0.1%) version of `exp`.
+All other real number types fall back to `@fastmath exp(x)`.
+
+Handles `Inf` but not `NaN`:
+```
+julia> xs = Tuple([0, 1, Inf32, -Inf32, NaN32]);
+
+julia> fast_exp.(xs)
+(1.0017247f0, 2.717878f0, Inf32, 0.0f0, Inf32)
+
+julia> exp.(xs)
+(1.0f0, 2.7182817f0, Inf32, 0.0f0, NaN32)
+```
+"""
+@inline function fast_exp(x::Float32)
+    t = x * 1.442695041f0
+    i = unsafe_trunc(Int32, t) - signbit(t)
+    f = t - i
+    f2 = evalpoly(f, (1.00172476f0, 0.657636276f0, 0.3371894346f0))
+    y = reinterpret(Float32, reinterpret(Int32, f2) + (i << 23))
+    ifelse(x < -87.33655f0, 0.0f0, ifelse(x < 88.72283f0, y, Inf32))
+end
+# Adapted from code by njuffa, which claims /* max. rel. error <= 1.73e-3 on [-87,88] */
+# https://stackoverflow.com/questions/10552280/fast-exp-calculation-possible-to-improve-accuracy-without-losing-too-much-perfo/10792321#10792321
+
+# Direct translation to Float16, similar accuracy, twice as fast?
+@inline function fast_exp(x::Float16)
+    t = x * Float16(1.442)
+    i = unsafe_trunc(Int16, t) - signbit(t)
+    f = t - i
+    f2 = evalpoly(f, (Float16(1.002), Float16(0.6577), Float16(0.3372)))
+    y = reinterpret(Float16, reinterpret(Int16, f2) + (i << 10))
+    ifelse(x < Float16(-9.7), Float16(-0.0), ifelse(x < Float16(11.09), y, Inf16))
+end
+
+fast_exp(x::Real) = @fastmath exp(x)
+
+#=
+
+julia> let x = randn(Float32, 1000)
+           y = similar(x)
+           @btime $y .= exp.($x)
+           @btime @fastmath $y .= exp.($x)
+           @btime @turbo $y .= exp.($x)
+           @btime $y .= NNlib.fast_exp.($x)
+       end;
+  min 3.938 μs, mean 3.984 μs (0 allocations)
+  min 3.422 μs, mean 3.450 μs (0 allocations)
+  min 459.812 ns, mean 462.233 ns (0 allocations)
+  min 249.777 ns, mean 251.146 ns (0 allocations)
+
+  14.190 μs (0 allocations: 0 bytes) # another computer
+  12.435 μs (0 allocations: 0 bytes)
+  1.311 μs (0 allocations: 0 bytes)
+  553.774 ns (0 allocations: 0 bytes)
+
+julia> let x = CUDA.randn(Float32, 100, 100_000)
+           y = similar(x)
+           @btime CUDA.@sync $y .= exp.($x)
+           @btime CUDA.@sync @fastmath $y .= exp.($x)
+           @btime CUDA.@sync $y .= NNlib.fast_exp.($x)
+       end;
+  124.673 μs (27 allocations: 1.36 KiB)
+  124.202 μs (27 allocations: 1.36 KiB)
+  124.066 μs (27 allocations: 1.36 KiB)
+
+=#
+
+export fast_softmax
+
+"""
+    fast_softmax(x; dims = 1)
+
+For `x::AbstractArray{Float32}`, this is a faster but less accurate `softmax`.
+
+The mean error is about 0.01% on `x = randn(Float32, ...)`,
+roughly 4 decimal digits worse than `softmax`.
+It is about 5x faster.
+
+# Example
+```
+julia> [fast_softmax([-Inf32,1,2,3]) softmax([-Inf32,1,2,3])] # OK with -Inf
+4×2 Matrix{Float32}:
+ 0.0        0.0
+ 0.0898185  0.0900306
+ 0.244652   0.244728
+ 0.66553    0.665241
+
+julia> [fast_softmax([1,Inf32]) softmax([1,Inf32])] # does not handle +Inf
+2×2 Matrix{Float32}:
+   0.0  0.0
+ NaN    1.0
+```
+"""
+fast_softmax(x::AbstractArray{T}; dims = 1) where {T} = fast_softmax!(similar(x, float(T)), x; dims)
+function fast_softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
+    max_ = fast_maximum(x; dims)
+    out .= fast_exp.(x .- max_)
+    tmp = dims isa Colon ? sum(out) : sum!(max_, out)
+    return out ./= tmp
+end
+
+function rrule(::typeof(fast_softmax), x; dims = 1)
+    y = fast_softmax(x; dims)
+    fast_softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
+    return y, fast_softmax_pullback
+end
+
+#=
+
+julia> let x = randn(Float32, 100, 1000) # CPU
+           y = similar(x)
+           @btime softmax!($y, $x)
+           @btime NNlib.fast_softmax!($y, $x)
+       end;
+  min 647.000 μs, mean 657.488 μs (1 allocation, 4.06 KiB)
+  min 133.917 μs, mean 139.647 μs (1 allocation, 4.06 KiB)
+
+  1.646 ms (1 allocation: 4.06 KiB) # another computer
+  322.792 μs (1 allocation: 4.06 KiB)
+
+julia> let x = CUDA.rand(Float32, 100, 1000) # same (small) size
+           y = similar(x)
+           @btime CUDA.@sync softmax!($y, $x)
+           @btime CUDA.@sync NNlib.fast_softmax!($y, $x) # faster because it skips a launch
+       end;
+  151.148 μs (262 allocations: 12.94 KiB)
+  78.955 μs (153 allocations: 7.50 KiB)
+
+# removing the all(isfinite, max_) check, the full-precision softmax! is as fast:
+  79.720 μs (153 allocations: 7.50 KiB)
+  80.410 μs (153 allocations: 7.50 KiB)
+
+julia> let x = CUDA.rand(Float32, 100, 10_000) # 10 times bigger
+           y = similar(x)
+           @btime CUDA.@sync softmax!($y, $x)
+           @btime CUDA.@sync NNlib.fast_softmax!($y, $x)
+       end;
+  205.560 μs (262 allocations: 12.94 KiB)
+  150.375 μs (153 allocations: 7.50 KiB)
+
+# removing the all(isfinite, max_) check:
+  149.104 μs (153 allocations: 7.50 KiB)
+  149.570 μs (153 allocations: 7.50 KiB)
+
+julia> let x = CUDA.rand(Float32, 100, 100_000) # 100 times bigger
+           y = similar(x)
+           @btime CUDA.@sync softmax!($y, $x)
+           @btime CUDA.@sync NNlib.fast_softmax!($y, $x)
+       end;
+  1.673 ms (309 allocations: 15.66 KiB) # difference is noise, I think
+  1.729 ms (200 allocations: 10.27 KiB)
+
+# removing the all(isfinite, max_) check:
+  1.740 ms (200 allocations: 10.27 KiB)
+  1.708 ms (200 allocations: 10.27 KiB)
+
+=#
+
+
 """
     logsoftmax(x; dims = 1)
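As a rough check of the accuracy claims above, here is a minimal sketch comparing `NNlib.fast_exp` against `Base.exp` over the supported `Float32` range. It assumes this branch of NNlib is loaded; the error levels in the comments restate the docstring's ~0.1% figure and njuffa's 1.73e-3 bound rather than report new measurements.

```
using NNlib, Statistics

xs = range(-87f0, 88f0; length = 1_000_001)    # Float32 grid inside the non-overflow range
relerr(x) = abs(NNlib.fast_exp(x) - exp(Float64(x))) / exp(Float64(x))
errs = relerr.(xs)
println("max rel. error  ≈ ", maximum(errs))   # expect something near njuffa's 1.73e-3 bound
println("mean rel. error ≈ ", mean(errs))      # expect roughly the 0.1% quoted in the docstring
```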
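A similar sketch exercises `fast_softmax` end to end, including the `rrule` above. Zygote, the array size, and the tolerances mentioned in the comments are assumptions for illustration, not part of this diff.

```
using NNlib, Zygote

x = randn(Float32, 100, 32)
y_fast, y_ref = fast_softmax(x), softmax(x)
@show maximum(abs, y_fast .- y_ref)            # forward pass: expect differences around 1e-4

loss(f, x) = sum(abs2, f(x))
g_fast = Zygote.gradient(x -> loss(fast_softmax, x), x)[1]
g_ref  = Zygote.gradient(x -> loss(softmax, x), x)[1]
@show maximum(abs, g_fast .- g_ref)            # both pullbacks call ∇softmax_data, so any gap comes from y
```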