Skip to content
8 changes: 8 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ tanhshrink
trelu
```

## Attention

```@docs
dot_product_attention
dot_product_attention_scores
make_causal_mask
```

## Softmax

`Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.
Expand Down
3 changes: 3 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ for f in ACTIVATIONS
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases

include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask

include("dropout.jl")
export dropout, dropout!

Expand Down
137 changes: 137 additions & 0 deletions src/attention.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])
Multihead dot product attention used in transformer architectures.
The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.
Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.
See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.
# Arguments
- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
The mask is applied to the attention scores just before the softmax.
See [`make_causal_mask`](@ref) fore creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.
# Examples
```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

x, α = dot_product_attention(q, k, v, args...; kws...)

x = reshape(x, size(x, 1), size(x, 2), batch_size...)
α = reshape(α, size(α)[1:3]..., batch_size...)
return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
fdrop=identity, mask=nothing, nheads=1)

(size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

# Multihead attention. TODO create fastpath for singlehead attention.
q, k, v = split_heads.((q, k, v), nheads)
x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
# [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
# [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
# [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

α = dot_product_attention_scores(q, k, bias; fdrop, mask)
# [α] = [kv_len, q_len, nheads, batch_size]

# The following permutedims and batched_mul are equivalent to
# @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
vt = permutedims(v, (1, 3, 2, 4))
x = batched_mul(vt, α)
x = permutedims(x, (1, 3, 2, 4))
# [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size]
return x, α
end

"""
dot_product_attention_scores(query, key, [bias]; [fdrop, mask])
Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
fdrop=identity, mask=nothing) where T

# The following permutedims and batched_mul are equivalent to
# @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
kt = permutedims(k, (3, 1, 2, 4))
qt = permutedims(q, (1, 3, 2, 4)) ./ T(size(q, 1))
logits = batched_mul(kt, qt)
# [logits] = [kv_len, q_len, nheads, batch_size]

if bias !== nothing
logits = logits .+ bias
end

if mask !== nothing
neginf = typemin(eltype(logits))
logits = ifelse.(mask, logits, neginf)
end
Copy link
Member

@ToucheSir ToucheSir Jan 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about making these internal methods which dispatch on nothing? That way there's zero control flow and Zygote is happy. The main question is whether the additional code + complexity introduced would be worth the compile and runtime reduction.


α = softmax(logits, dims=1)
return fdrop(α)
end

"""
make_causal_mask(x, dims=2)
Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.
Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
len = size(x, dims)
mask = triu(trues_like(x, (len, len)))
return mask
end

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

@non_differentiable make_causal_mask(x)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)

15 changes: 13 additions & 2 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A)
batched_mul(A, B) -> C
A ⊠ B # \\boxtimes

Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent
any indices in the last dimensions.

If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.

To transpose each matrix, apply `batched_transpose` to the array,
or `batched_adjoint` for conjugate-transpose:
Expand Down Expand Up @@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My vote is to make this an internal _batched_mul_4 or something for now. Partly because I think explaining what does and doesn't work becomes more complicated with this method. And that doesn't have to be solved to add attention.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pity to not make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method?

batch_size = size(x)[3:end]
@assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
x2 = reshape(x, size(x, 1), size(x, 2), :)
y2 = reshape(y, size(y, 1), size(y, 2), :)
z = batched_mul(x2, y2)
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
throw(DimensionMismatch("batch size mismatch: A != B"))
Expand Down
2 changes: 1 addition & 1 deletion src/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings

end

C
return C
end
end
end
76 changes: 76 additions & 0 deletions test/attention.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
@testset "different batchsizes" begin
n = 15
lenq = 3
lenkv = 4
for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]
q = rand(Float32, n, lenq, batch_size...)
k = rand(Float32, n, lenkv, batch_size...)
v = rand(Float32, n, lenkv, batch_size...)
y, α = dot_product_attention(q, k, v; nheads)
@test y isa Array{Float32}
@test size(y) == (n, lenq, batch_size...)
@test size(α) == (lenkv, lenq, nheads, batch_size...)
@test sum(α, dims=1) ones(1, lenq, nheads, batch_size...)
end
end

@testset "dot_product_attention_scores" begin
q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
α = dot_product_attention_scores(q, k)
q2, k2 = reshape.((q, k), 8, 3, 1)
y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
@test α α2
end

@testset "specific results" begin
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
y, α = dot_product_attention(q, k, v; nheads=2)
ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468]
ytrue = reshape(ytrue, 4, 3, 1)
αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966]
αtrue = reshapetrue, 3, 3, 2, 1)
@test y ytrue
@test α αtrue
end

@testset "mask" begin
q = rand(4, 2, 3, 1)
k = rand(4, 2, 5, 1)

mask = rand(Bool, (5, 3))
α = dot_product_attention_scores(q, k; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)

@testset "causal" begin
x = rand(4, 2, 3, 1)
mask = make_causal_mask(x, dims=3)
α = dot_product_attention_scores(x, x; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)
end
end

@testset "dropout" begin
q = k = v = rand(10, 10, 10)
fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
@test 0.6 > mean(>(0), α) > 0.4
end

@testset "bias" begin
q = rand(4, 5, 1)
k = v = rand(4, 3, 1)
bias = randn(3, 5)
y, α = dot_product_attention(q, k, v, bias; nheads=2)
@test size(α) == (3, 5, 2, 1)
@test size(y) == (4, 5, 1)
end

@testset "gradient" begin
q = rand(4, 5, 1)
k = v = rand(4, 3, 1)
bias = randn(3, 5)
y, α = dot_product_attention(q, k, v, bias; nheads=2)
gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ include("test_utils.jl")
include("activations.jl")
end

@testset "Attention" begin
include("attention.jl")
end

@testset "Batched Multiplication" begin
include("batchedmul.jl")
end
Expand Down