Commit 52bee5e

causal mask
1 parent 5baf2f0 commit 52bee5e

File tree: 1 file changed (+80 −46 lines)

src/layers/attention.jl

Lines changed: 80 additions & 46 deletions
@@ -2,8 +2,10 @@ using Flux, Functors, Test, LinearAlgebra, Random, Statistics
 using CUDA
 using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
 using NeuralAttentionlib
+using NeuralAttentionlib: score_returning
 using BenchmarkTools
 using Flux: glorot_uniform
+using MLUtils
 CUDA.allowscalar(false)
 
 const A3{T} = AbstractArray{T, 3}
@@ -18,19 +20,22 @@ Multi-head dot-product attention layer.
 # Arguments
 
 - `dims`: ...
-- `nheads`: number of heads
+- `num_heads`: number of heads.
 - `init`: weight initializer for the Dense layers.
 - `bias`: whether pointwise QKVO dense transforms use bias.
 - `attn_dropout_prob`: dropout probability after the self-attention layer
 - `proj_dropout_prob`: dropout probability after the projection layer
 
 # Forward
+
+    (::MultiHeadAttention)(q_in, k_in, v_in; [mask, with_weights])
 
-- `in_q`: input tensor of shape `(batch_size, seq_len, dims)`
-- `in_k`: input tensor of shape `(batch_size, seq_len, dims)`
-- `in_v`: input tensor of shape `(batch_size, seq_len, dims)`
-- `mask`: input tensor of shape `(batch_size, seq_len, seq_len)`
-- `return_weights`: whether to return the attention weights
+- `q_in`: input array of size `(seq_len, dims)`
+- `k_in`: input array of size `(seq_len, dims)`
+- `v_in`: input array of size `(seq_len, dims)`
+- `mask`: input array broadcastable to size
+  `(kv_len, q_len, num_heads, batch_size)`. Default `nothing`.
+- `with_weights`: whether to return the attention weights. Default `false`.
 
 # Examples
 
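For orientation, a minimal usage sketch of the forward call documented above (not part of the commit; array sizes follow the `[dim, len, batch_size]` layout used in the comments inside the forward method, and all values are illustrative):

    dims, num_heads, q_len, kv_len, batch_size = 64, 8, 6, 10, 4
    mha  = MultiHeadAttention(dims, num_heads)
    q_in = rand(Float32, dims, q_len, batch_size)
    k_in = rand(Float32, dims, kv_len, batch_size)
    v_in = rand(Float32, dims, kv_len, batch_size)

    y    = mha(q_in, k_in, v_in)                     # output: (dims, q_len, batch_size)
    y, α = mha(q_in, k_in, v_in; with_weights=true)  # α: (kv_len, q_len, num_heads, batch_size)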
@@ -68,28 +73,33 @@ end
 # 8 => (8, 8) => 8
 # (8, 8, 8) => (8, 8) => 8 # (q_in, k_in, v_in) => (qk, v) => out
 mha_process_dims(dims::Int) =
-    (; q_in = dims, k_in = dims, v_in = dims,
-       qk = dims, v = dims, out = dims)
+    (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)
 
-mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
-    (; q_in = in, k_in = in, v_in = in,
-       qk = qkv, v = qkv, out)
-
-mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
-    (; q_in = in[1], k_in = in[2], v_in = in[3],
-       qk = qkv, v = qkv, out)
-
-mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) =
-    (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out)
-
-mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) =
-    (; q_in = in, k_in = in, v_in = in, qk, v, out)
+const TuplInt2 = Union{Int, Tuple{Int, Int}}
+const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}
 
+function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
+    if in isa Int
+        q_in = k_in = v_in = in
+    else
+        q_in, k_in, v_in = in
+    end
+    if qkv isa Int
+        qk = v = qkv
+    else
+        qk, v = qkv
+    end
+    return (; q_in, k_in, v_in, qk, v, out)
+end
 
 # self-attention
-(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
+(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...)
+
+# key and value are the same
+(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)
 
-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio)
+function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
+        with_weights=false, mask=nothing, impl=:tullio)
     ## [q_in] = [q_in_dim, q_len, batch_size]
     ## [k_in] = [k_in_dim, kv_len, batch_size]
     ## [v_in] = [v_in_dim, kv_len, batch_size]
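As a quick reference (not from the commit), the `Pair` specs accepted by the consolidated `mha_process_dims` above expand as follows:

    mha_process_dims(8)                           # (q_in = 8, k_in = 8, v_in = 8, qk = 8, v = 8, out = 8)
    mha_process_dims(8 => (8, 8) => 8)            # same result, written out as in the comment above
    mha_process_dims((4, 6, 6) => (8, 10) => 12)  # (q_in = 4, k_in = 6, v_in = 6, qk = 8, v = 10, out = 12)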
@@ -100,11 +110,9 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
     # [v] = [v_dim, kv_len, batch_size]
 
     if impl == :tullio
-        x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop)
+        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     elseif impl == :nalib
-        x, α = NeuralAttentionlib.multihead_qkv_attention(
-                    NeuralAttentionlib.score_returning,
-                    m.num_heads, q, k, v)
+        x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v)
     else
         error("Unknown attention implementation")
     end
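A hypothetical call showing the `impl` switch (illustrative values; note that in this commit only the `:tullio` path receives the `mask` keyword):

    x   = rand(Float32, 64, 10, 4)
    mha = MultiHeadAttention(64, 8)
    y_tullio = mha(x)               # default impl = :tullio
    y_nalib  = mha(x; impl=:nalib)  # NeuralAttentionlib backend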
@@ -114,29 +122,41 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
     return with_weights ? (x, α) : x
 end
 
-# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
-function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing)
-    α = dot_product_attention_weights(q, k; dropout)
-    # [α] = [kv_len, q_len, num_heads, batch_size]
-    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
-    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
-    return x, α
-end
+reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
 
 function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
     q, k, v = reshape_heads.((q, k, v), num_heads)
     x, α = dot_product_attention(q, k, v; kws...)
     return flatten_heads(x), α
 end
 
-reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
-flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
+function dot_product_attention(q::A4, k::A4, v::A4;
+        dropout=nothing, bias=nothing, mask=nothing)
+
+    α = dot_product_attention_weights(q, k; dropout, bias, mask)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
+    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+    return x, α
+end
 
 function dot_product_attention_weights(q::A4{T}, k::A4{T};
-        dropout=nothing) where T
+        dropout=nothing, mask=nothing, bias=nothing) where T
+
     q = q ./ T(size(q, 1))
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    if bias !== nothing
+        α = α .+ bias
+    end
+    if mask !== nothing
+        neginf = typemin(eltype(α))
+        α = ifelse.(mask, α, neginf)
+    end
+
     α = softmax(α, dims=1)
     return dropout === nothing ? α : dropout(α)
 end
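The masking step above fills blocked positions with the most negative representable value before the softmax, so they end up with zero weight. A minimal standalone sketch of that step (not from the commit), assuming a Bool mask that is `true` where attention is allowed:

    using Flux: softmax

    kv_len, q_len, num_heads, batch_size = 4, 4, 1, 1
    α = randn(Float32, kv_len, q_len, num_heads, batch_size)  # raw scores, as after the @tullio line

    mask = trues(kv_len, q_len)   # broadcasts over the head and batch dimensions
    mask[end, :] .= false         # e.g. block attention to the last key position

    neginf = typemin(eltype(α))
    α = ifelse.(mask, α, neginf)  # blocked entries become typemin(Float32) == -Inf32
    α = softmax(α, dims=1)        # each column still sums to 1; blocked entries get weight 0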
@@ -162,6 +182,13 @@ function (proj::QKVProj)(q_in, k_in, v_in)
     return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
 end
 
+function make_causal_mask(x::A3)
+    d, len, batch_size = size(x)
+    mask = tril(ones_like(x, (len, len)))
+    return mask
+end
+
+@non_differentiable make_causal_mask(x)
 
 function perf(dim, len, batch_size, num_heads)
     mha = MultiHeadAttention(dim, num_heads)
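A hypothetical end-to-end use of the new `make_causal_mask` together with the layer's `mask` keyword (not part of the commit; `ones_like` is assumed to come from the newly imported MLUtils, `@non_differentiable` presumably from ChainRulesCore to keep mask construction out of gradients, and the `.> 0` conversion to Bool is an assumption made because the weights code branches with `ifelse.(mask, α, neginf)`):

    dim, len, batch_size, num_heads = 64, 10, 32, 8
    mha = MultiHeadAttention(dim, num_heads)
    x   = rand(Float32, dim, len, batch_size)

    causal = make_causal_mask(x) .> 0                 # (len, len) Bool mask from the tril of ones
    y, α   = mha(x; mask=causal, with_weights=true)   # self-attention: m(x) == m(x, x, x)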
@@ -222,14 +249,21 @@ test(4, 2, 2, 1)
 
 perf(128, 8, 128, 32)
 # tullio
-# 5.862 ms (85 allocations: 6.75 MiB)
-# 14.291 ms (1046 allocations: 17.17 MiB)
+# 5.475 ms (80 allocations: 7.25 MiB)
+# 13.073 ms (1172 allocations: 18.18 MiB)
+# tullio - 6 threads
+# 4.818 ms (192 allocations: 7.26 MiB)
+# 10.927 ms (1398 allocations: 18.19 MiB)
 # nalib
-# 6.331 ms (90 allocations: 7.75 MiB)
-# 16.186 ms (690 allocations: 16.17 MiB)
+# 6.040 ms (91 allocations: 7.75 MiB)
+# 14.542 ms (696 allocations: 16.17 MiB)
+# nalib - 6 threads
+# 7.832 ms (187 allocations: 7.76 MiB)
+# 29.823 ms (988 allocations: 16.19 MiB)
 # tullio - gpu
-# 141.365 μs (499 allocations: 22.81 KiB)
-# 804.018 μs (2228 allocations: 113.45 KiB)
+# 147.746 μs (523 allocations: 24.59 KiB)
+# 957.111 μs (2413 allocations: 127.88 KiB)
 # nalib - gpu
-# 163.487 μs (410 allocations: 18.02 KiB)
-# 673.463 μs (1521 allocations: 84.64 KiB)
+# 165.109 μs (411 allocations: 18.05 KiB)
+# 659.685 μs (1527 allocations: 86.09 KiB)
+
