Commit eabcc02

address some review comments
1 parent 2193639 commit eabcc02

File tree

1 file changed (+24, -27 lines)


src/attention.jl

Lines changed: 24 additions & 27 deletions
@@ -3,7 +3,7 @@ const AA4{T} = AbstractArray{T,4}
 const AA{N,T} = AbstractArray{T,N}
 
 """
-    dot_product_attention(query, key, value; [bias, droput_fn, mask, num_heads])
+    dot_product_attention(query, key, value; [bias, fdrop, mask, nheads])
 
 Multihead dot product attention used in transformer architectures.
 
@@ -15,12 +15,11 @@ and the sequece length, then an arbitrary number of batch dimensions or none.
 - `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
 - `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
 - `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
-- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`.
+- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+- `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout).
+- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
   Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
-- `dropout_fn`: A dropout function to apply on the attention scores. Default `nothing`.
-- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, num_heads, batch_size)`.
-  Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
-- `num_heads`: Number of heads to split the input arrays into. Default `1`.
+- `nheads`: Number of heads to split the input arrays into. Default `1`.
 
 # Examples
 
@@ -29,7 +28,7 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
 y, α = dot_product_attention(q, k, v)
 ```
 """
-function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...) where N
+function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N
     batch_size = size(q)[3:end]
 
     batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
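
For orientation, here is a rough usage sketch of the renamed keyword interface, assuming the function is reachable from NNlib. The array sizes, the `Float32` element type, and the explicit `fdrop=identity` are illustrative assumptions, not part of this commit:

```julia
using NNlib: dot_product_attention

# Self-attention-style shapes: (features, seq_len, batch).
q = rand(Float32, 8, 20, 4)   # (qk_dim, q_len, batch)
k = rand(Float32, 8, 20, 4)   # (qk_dim, kv_len, batch)
v = rand(Float32, 6, 20, 4)   # (v_dim, kv_len, batch)

# `nheads` replaces `num_heads` and `fdrop` replaces `dropout_fn`;
# qk_dim and v_dim must both be divisible by nheads.
y, α = dot_product_attention(q, k, v; nheads=2, mask=:causal, fdrop=identity)

size(y)  # expected (6, 20, 4): v_dim features per query position and batch item
size(α)  # expected (20, 20, 2, 4): (kv_len, q_len, nheads, batch)
```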
@@ -39,7 +38,7 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...
     q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
 
     # Multihead attention. TODO create fastpath for singlehead attention.
-    q, k, v = split_heads.((q, k, v), num_heads)
+    q, k, v = split_heads.((q, k, v), nheads)
     x, α = _dot_product_attention(q, k, v; kws...)
     x = join_heads(x)
 
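
The two context lines in this hunk are the reshaping steps: any trailing batch dimensions are flattened into one, and the feature dimension is then split across heads. A small standalone sketch with made-up sizes (the helper definitions are copied from the last hunk of this diff):

```julia
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

q = rand(Float32, 8, 20, 3, 5)                 # (qk_dim, q_len, batch dims...)
qflat = reshape(q, size(q, 1), size(q, 2), :)  # trailing batch dims collapsed
size(qflat)                                    # (8, 20, 15)

qh = split_heads(qflat, 2)                     # feature dim split across 2 heads
size(qh)                                       # (4, 2, 20, 15)
join_heads(qh) == qflat                        # true: join_heads inverts split_heads
```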

@@ -49,17 +48,17 @@ function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; num_heads=1, kws...
 end
 
 function _dot_product_attention(q::AA4, k::AA4, v::AA4;
-            dropout_fn=nothing, bias=nothing, mask=nothing)
+            fdrop=nothing, bias=nothing, mask=nothing)
 
-    α = dot_product_attention_scores(q, k; dropout_fn, bias, mask)
-    # [α] = [kv_len, q_len, num_heads, batch_size]
+    α = dot_product_attention_scores(q, k; fdrop, bias, mask)
+    # [α] = [kv_len, q_len, nheads, batch_size]
 
     # The following permutedims and batched_mul are equivalent to
     # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
     vt = permutedims(v, (1, 3, 2, 4))
     x = batched_mul(vt, α)
     x = permutedims(x, (1, 3, 2, 4))
-    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+    # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size]
     return x, α
 end
 
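
As a reading aid for the `permutedims`/`batched_mul` pair above (and the analogous pair in `dot_product_attention_scores` below), here is a naive loop version of the contraction written in the `@tullio` comment. `naive_mix` is a hypothetical name introduced only for illustration; it is not part of the source:

```julia
# x[d, h, i, b] = Σ_j α[j, i, h, b] * v[d, h, j, b]
# with v of size (v_dim ÷ nheads, nheads, kv_len, batch)
# and  α of size (kv_len, q_len, nheads, batch).
function naive_mix(α::AbstractArray{T,4}, v::AbstractArray{T,4}) where T
    d, nheads, kv_len, batch = size(v)
    q_len = size(α, 2)
    x = zeros(T, d, nheads, q_len, batch)
    for b in 1:batch, i in 1:q_len, h in 1:nheads, j in 1:kv_len, dd in 1:d
        x[dd, h, i, b] += α[j, i, h, b] * v[dd, h, j, b]
    end
    return x
end
```

Permuting `v` to `(d, kv_len, nheads, batch)` turns the inner sum over `j` into an ordinary matrix product per `(head, batch)` slice, which is exactly what `batched_mul` computes.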

@@ -68,35 +67,33 @@ end
 
 Return the attention scores for the [`dot_product_attention`](@ref).
 
-Input arrays must have dimensions `(num_features ÷ num_heads, num_heads, sequence_length, batch_size)`
+Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
 
 """
 function dot_product_attention_scores(q::AA4{T}, k::AA4{T};
-            dropout_fn=nothing, mask=nothing, bias=nothing) where T
+            fdrop=identity, mask=nothing, bias=nothing) where T
 
-    q = q ./ T(size(q, 1))
-
     # The following permutedims and batched_mul are equivalent to
-    # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
+    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
     kt = permutedims(k, (3, 1, 2, 4))
-    qt = permutedims(q, (1, 3, 2, 4))
-    α = batched_mul(kt, qt)
-    # [α] = [kv_len, q_len, num_heads, batch_size]
+    qt = permutedims(q, (1, 3, 2, 4)) ./ T(size(q, 1))
+    logits = batched_mul(kt, qt)
+    # [logits] = [kv_len, q_len, nheads, batch_size]
 
     if bias !== nothing
-        α = α .+ bias
+        logits = logits .+ bias
     end
 
     if mask !== nothing
         if mask === :causal
-            mask = make_causal_mask(α)
+            mask = make_causal_mask(logits)
         end
-        neginf = typemin(eltype(α))
-        α = ifelse.(mask, α, neginf)
+        neginf = typemin(eltype(logits))
+        logits = ifelse.(mask, logits, neginf)
     end
 
-    α = softmax(α, dims=1)
-    return dropout_fn === nothing ? α : dropout_fn(α)
+    α = softmax(logits, dims=1)
+    return fdrop(α)
 end
 
 """
@@ -116,7 +113,7 @@ end
 trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
 falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)
 
-split_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
 join_heads(x) = reshape(x, :, size(x)[3:end]...)
 
 @non_differentiable make_causal_mask(x)
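
To see what the `:causal` masking in `dot_product_attention_scores` amounts to for a single head and batch item: entries where the mask is `false` are set to `typemin` before the softmax, so they end up with zero weight. The hand-built mask below is only a stand-in for `make_causal_mask`, and `softmax` is assumed to come from NNlib:

```julia
using NNlib: softmax

logits = randn(Float32, 3, 3)              # (kv_len, q_len) for one head and batch item
mask = [j <= i for j in 1:3, i in 1:3]     # query position i sees key positions j <= i
masked = ifelse.(mask, logits, typemin(Float32))
α = softmax(masked, dims=1)                # masked-out entries receive zero weight
vec(sum(α, dims=1))                        # ≈ [1, 1, 1]: each query's weights sum to 1
```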
