
Commit c28b4d3

bias is positional argument
1 parent 3bbe779 commit c28b4d3

1 file changed: +30 -20 lines

src/attention.jl

Lines changed: 30 additions & 20 deletions
@@ -3,21 +3,28 @@ const AA4{T} = AbstractArray{T,4}
 const AA{N,T} = AbstractArray{T,N}

 """
-    dot_product_attention(query, key, value; [bias, fdrop, mask, nheads])
+    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

 Multihead dot product attention used in transformer architectures.

 The input arrays must have the first two dimensions given by the number of features
-and the sequece length, then an arbitrary number of batch dimensions or none.
+and the sequence length, then an arbitrary number of batch dimensions or none.
+
+Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
+of size `(kv_len, q_len, nheads, batch_size...)`.
+
+See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

 # Arguments

 - `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
 - `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
 - `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
-- `bias`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  It will be added to the attention scores before applying the softmax. Default `nothing`.
 - `fdrop`: A dropout function or layer to apply on the attention scores. Default `identity` (no dropout).
-- `mask`: Either `nothing` or an input array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  The mask will be applied to the attention scores before the softmax.
   Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
 - `nheads`: Number of heads to split the input arrays into. Default `1`.
@@ -28,36 +35,37 @@ q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
 y, α = dot_product_attention(q, k, v)
 ```
 """
-function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}; nheads=1, kws...) where N
+function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
     batch_size = size(q)[3:end]
-
     batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
-    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
-    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
-
     q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

-    x, α = dot_product_attention(q, k, v; nheads, kws...)
+    x, α = dot_product_attention(q, k, v, args...; kws...)

     x = reshape(x, size(x, 1), size(x, 2), batch_size...)
     α = reshape(α, size(α)[1:3]..., batch_size...)
     return x, α
 end

-function dot_product_attention(q::AA3, k::AA3, v::AA3; nheads=1, kws...)
+function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
+        fdrop=identity, mask=nothing, nheads=1)
+
+    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
+    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
+    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))
+
     # Multihead attention. TODO create fastpath for singlehead attention.
     q, k, v = split_heads.((q, k, v), nheads)
-    x, α = _dot_product_attention(q, k, v; kws...)
+    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
     return join_heads(x), α
 end

-function _dot_product_attention(q::AA4, k::AA4, v::AA4;
-        fdrop=identity, bias=nothing, mask=nothing)
+function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
     # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
     # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
     # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]
-
-    α = dot_product_attention_scores(q, k; fdrop, bias, mask)
+
+    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
     # [α] = [kv_len, q_len, nheads, batch_size]

     # The following permutedims and batched_mul are equivalent to
@@ -70,14 +78,16 @@ function _dot_product_attention(q::AA4, k::AA4, v::AA4;
 end

 """
-    dot_product_attention_scores(query, key; [bias, droput_fn, mask])
+    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

 Return the attention scores for the [`dot_product_attention`](@ref).
+Input arrays must have dimensions
+`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

-Input arrays must have dimensions `(num_features ÷ nheads, nheads, sequence_length, batch_size)`.
+See [`dot_product_attention`](@ref) for more details.
 """
-function dot_product_attention_scores(q::AA4{T}, k::AA4{T};
-        fdrop=identity, mask=nothing, bias=nothing) where T
+function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
+        fdrop=identity, mask=nothing) where T

     # The following permutedims and batched_mul are equivalent to
     # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
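
For illustration only (not part of the commit): a minimal usage sketch of the call signatures after this change. It assumes this file is NNlib.jl's `src/attention.jl` and that both functions are exported; neither the package name nor its exports are shown on this page. Array sizes follow the docstring example above.

```julia
# Sketch only: assumes the file belongs to NNlib.jl, which exports both functions.
using NNlib

# Sizes follow the docstring: (qk_dim, q_len, batch), (qk_dim, kv_len, batch), (v_dim, kv_len, batch).
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)

# bias is broadcastable to (kv_len, q_len, nheads, batch_size) and is now passed positionally.
bias = randn(30, 20)

y, α = dot_product_attention(q, k, v, bias; nheads=2)
size(y)  # (20, 20, 2)     -- (v_dim, q_len, batch_size...)
size(α)  # (30, 20, 2, 2)  -- (kv_len, q_len, nheads, batch_size...)

# The scores function also takes bias positionally; its inputs are the head-split
# arrays of size (d ÷ nheads, nheads, len, batch) described in its docstring.
qh = reshape(q, 5, 2, 20, 2)
kh = reshape(k, 5, 2, 30, 2)
αh = dot_product_attention_scores(qh, kh, bias)
size(αh)  # (30, 20, 2, 2)
```

After this commit, `bias` is an optional positional argument defaulting to `nothing`, placed between `value` and the keyword arguments, as reflected in both updated docstring signatures.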

0 commit comments
