@@ -2,8 +2,10 @@ using Flux, Functors, Test, LinearAlgebra, Random, Statistics
 using CUDA
 using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio
 using NeuralAttentionlib
+using NeuralAttentionlib: score_returning
 using BenchmarkTools
 using Flux: glorot_uniform
+using MLUtils
 CUDA.allowscalar(false)

 const A3{T} = AbstractArray{T, 3}
@@ -18,19 +20,22 @@ Multi-head dot-product attention layer.
 # Arguments

 - `dims`: ...
-- `nheads`: number of heads
+- `num_heads`: number of heads.
 - `init`: weight initializer for the Dense layers.
 - `bias`: whether pointwise QKVO dense transforms use bias.
 - `attn_dropout_prob`: dropout probability after the self-attention layer
 - `proj_dropout_prob`: dropout probability after the projection layer

 # Forward
+
+    (::MultiHeadAttention)(q_in, k_in, v_in; [mask, with_weights])

-- `in_q`: input tensor of shape `(batch_size, seq_len, dims)`
-- `in_k`: input tensor of shape `(batch_size, seq_len, dims)`
-- `in_v`: input tensor of shape `(batch_size, seq_len, dims)`
-- `mask`: input tensor of shape `(batch_size, seq_len, seq_len)`
-- `return_weights`: whether to return the attention weights
+- `q_in`: input array of size `(q_in_dim, q_len, batch_size)`.
+- `k_in`: input array of size `(k_in_dim, kv_len, batch_size)`.
+- `v_in`: input array of size `(v_in_dim, kv_len, batch_size)`.
+- `mask`: input array broadcastable to size
+  `(kv_len, q_len, num_heads, batch_size)`. Default `nothing`.
+- `with_weights`: whether to return the attention weights. Default `false`.

 # Examples

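For orientation, a minimal usage sketch of the call signature documented above. The feature-major `(dim, len, batch)` layout follows the shape comments in the forward method, and the `MultiHeadAttention(dim, num_heads)` constructor form is the one used in `perf` further down; both are assumptions about parts of the file not shown in this diff.

```julia
using Flux

mha = MultiHeadAttention(64, 8)               # 64 features in/out, 8 heads (constructor form from `perf`)
q  = rand(Float32, 64, 10, 32)                # (q_in_dim, q_len, batch_size)
kv = rand(Float32, 64, 12, 32)                # (kv_in_dim, kv_len, batch_size)

y        = mha(q, kv, kv)                     # (out_dim, q_len, batch_size)
y_self   = mha(q)                             # self-attention shorthand: q = k = v
y, alpha = mha(q, kv, kv; with_weights=true)  # alpha: (kv_len, q_len, num_heads, batch_size)
```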
@@ -68,28 +73,33 @@
 # 8 => (8, 8) => 8
 # (8, 8, 8) => (8, 8) => 8   # (q_in, k_in, v_in) => (qk, v) => out
 mha_process_dims(dims::Int) =
-    (; q_in = dims, k_in = dims, v_in = dims,
-       qk = dims, v = dims, out = dims)
+    (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

-mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) =
-    (; q_in = in, k_in = in, v_in = in,
-       qk = qkv, v = qkv, out)
-
-mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) =
-    (; q_in = in[1], k_in = in[2], v_in = in[3],
-       qk = qkv, v = qkv, out)
-
-mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) =
-    (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out)
-
-mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) =
-    (; q_in = in, k_in = in, v_in = in, qk, v, out)
+const TuplInt2 = Union{Int, Tuple{Int, Int}}
+const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}

+function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
+    if in isa Int
+        q_in = k_in = v_in = in
+    else
+        q_in, k_in, v_in = in
+    end
+    if qkv isa Int
+        qk = v = qkv
+    else
+        qk, v = qkv
+    end
+    return (; q_in, k_in, v_in, qk, v, out)
+end

 # self-attention
-(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...)
+(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...)
+
+# key and value are the same
+(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...)

-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio)
+function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3;
+        with_weights=false, mask=nothing, impl=:tullio)
     ## [q_in] = [q_in_dim, q_len, batch_size]
     ## [k_in] = [k_in_dim, kv_len, batch_size]
     ## [v_in] = [v_in_dim, kv_len, batch_size]
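The `dims` specifications sketched in the `# 8 => (8, 8) => 8` comments above expand through `mha_process_dims` as follows; the named tuples are written out from the method bodies as a quick sketch rather than actual REPL output.

```julia
mha_process_dims(8)                     # (q_in = 8, k_in = 8, v_in = 8, qk = 8, v = 8, out = 8)
mha_process_dims(8 => 8 => 8)           # same as above, written as a Pair

mha_process_dims((4, 6, 6) => 8 => 10)  # separate q/k/v input dims, shared qk/v dim
# (q_in = 4, k_in = 6, v_in = 6, qk = 8, v = 8, out = 10)

mha_process_dims(8 => (16, 12) => 10)   # distinct qk and v dims
# (q_in = 8, k_in = 8, v_in = 8, qk = 16, v = 12, out = 10)
```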
@@ -100,11 +110,9 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
     # [v] = [v_dim, kv_len, batch_size]

     if impl == :tullio
-        x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop)
+        x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop)
     elseif impl == :nalib
-        x, α = NeuralAttentionlib.multihead_qkv_attention(
-            NeuralAttentionlib.score_returning,
-            m.num_heads, q, k, v)
+        x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v)
     else
         error("Unknown attention implementation")
     end
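A short sketch of exercising the two backends side by side; the sizes are illustrative, and note that in this revision the `mask` keyword is only forwarded along the `:tullio` path.

```julia
mha = MultiHeadAttention(64, 8)
x = rand(Float32, 64, 10, 32)

y1, a1 = mha(x; with_weights=true, impl=:tullio)
y2, a2 = mha(x; with_weights=true, impl=:nalib)
y1 ≈ y2   # expected to match up to floating-point error (presumably what the elided `test` checks)
```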
@@ -114,29 +122,41 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals
     return with_weights ? (x, α) : x
 end

-# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
-function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing)
-    α = dot_product_attention_weights(q, k; dropout)
-    # [α] = [kv_len, q_len, num_heads, batch_size]
-    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
-    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
-    return x, α
-end
+reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
+flatten_heads(x) = reshape(x, :, size(x)[3:end]...)

 function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...)
     q, k, v = reshape_heads.((q, k, v), num_heads)
     x, α = dot_product_attention(q, k, v; kws...)
     return flatten_heads(x), α
 end

-reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...)
-flatten_heads(x) = reshape(x, :, size(x)[3:end]...)
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html
+function dot_product_attention(q::A4, k::A4, v::A4;
+        dropout=nothing, bias=nothing, mask=nothing)
+
+    α = dot_product_attention_weights(q, k; dropout, bias, mask)
+    # [α] = [kv_len, q_len, num_heads, batch_size]
+    @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
+    # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size]
+    return x, α
+end

 function dot_product_attention_weights(q::A4{T}, k::A4{T};
-        dropout=nothing) where T
+        dropout=nothing, mask=nothing, bias=nothing) where T
+
     q = q ./ T(√size(q, 1))
     @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b]
     # [α] = [kv_len, q_len, num_heads, batch_size]
+
+    if bias !== nothing
+        α = α .+ bias
+    end
+    if mask !== nothing
+        neginf = typemin(eltype(α))
+        α = ifelse.(mask, α, neginf)
+    end
+
     α = softmax(α, dims=1)
     return dropout === nothing ? α : dropout(α)
 end
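A small shape-level sketch of the per-head path defined above, with a hand-built Bool mask; the sizes are illustrative.

```julia
head_dim, num_heads, q_len, kv_len, batch = 8, 4, 5, 7, 2

q = rand(Float32, head_dim, num_heads, q_len, batch)
k = rand(Float32, head_dim, num_heads, kv_len, batch)
v = rand(Float32, head_dim, num_heads, kv_len, batch)

# Bool mask broadcastable to (kv_len, q_len, num_heads, batch):
# true keeps a logit, false sets it to typemin before the softmax.
mask = [j <= i for j in 1:kv_len, i in 1:q_len]

x, α = dot_product_attention(q, k, v; mask)
size(x)                    # (head_dim, num_heads, q_len, batch)
size(α)                    # (kv_len, q_len, num_heads, batch)
all(sum(α, dims=1) .≈ 1)   # weights are normalised over the key dimension
```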
@@ -162,6 +182,13 @@ function (proj::QKVProj)(q_in, k_in, v_in)
     return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in))
 end

+function make_causal_mask(x::A3)
+    d, len, batch_size = size(x)
+    mask = triu(ones_like(x, Bool, (len, len)))  # Bool mask: query i keeps keys j ≤ i
+    return mask
+end
+
+@non_differentiable make_causal_mask(x)

 function perf(dim, len, batch_size, num_heads)
     mha = MultiHeadAttention(dim, num_heads)
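A sketch of wiring `make_causal_mask` into the layer (same assumed constructor as above). The mask is a `(seq_len, seq_len)` Bool matrix and is upper-triangular because the weights are laid out as `[kv_len, q_len, ...]`, so query position `i` keeps exactly the key positions `j ≤ i`; `ifelse` also requires a Bool condition, hence the `Bool` element type.

```julia
x = rand(Float32, 64, 10, 32)     # (dims, seq_len, batch_size)
mha = MultiHeadAttention(64, 8)

mask = make_causal_mask(x)        # (10, 10) Bool, broadcast over heads and batch
y = mha(x; mask)                  # causal self-attention: no attending to later positions
```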
@@ -222,14 +249,21 @@ test(4, 2, 2, 1)

 perf(128, 8, 128, 32)
 # tullio
-# 5.862 ms (85 allocations: 6.75 MiB)
-# 14.291 ms (1046 allocations: 17.17 MiB)
+# 5.475 ms (80 allocations: 7.25 MiB)
+# 13.073 ms (1172 allocations: 18.18 MiB)
+# tullio - 6 threads
+# 4.818 ms (192 allocations: 7.26 MiB)
+# 10.927 ms (1398 allocations: 18.19 MiB)
 # nalib
-# 6.331 ms (90 allocations: 7.75 MiB)
-# 16.186 ms (690 allocations: 16.17 MiB)
+# 6.040 ms (91 allocations: 7.75 MiB)
+# 14.542 ms (696 allocations: 16.17 MiB)
+# nalib - 6 threads
+# 7.832 ms (187 allocations: 7.76 MiB)
+# 29.823 ms (988 allocations: 16.19 MiB)
 # tullio - gpu
-# 141.365 μs (499 allocations: 22.81 KiB)
-# 804.018 μs (2228 allocations: 113.45 KiB)
+# 147.746 μs (523 allocations: 24.59 KiB)
+# 957.111 μs (2413 allocations: 127.88 KiB)
 # nalib - gpu
-# 163.487 μs (410 allocations: 18.02 KiB)
-# 673.463 μs (1521 allocations: 84.64 KiB)
+# 165.109 μs (411 allocations: 18.05 KiB)
+# 659.685 μs (1527 allocations: 86.09 KiB)
+
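The timings above come from `perf`, whose body sits outside this diff; a minimal sketch of how comparable numbers could be collected with BenchmarkTools (the forward/gradient split per backend is an assumption):

```julia
using BenchmarkTools, CUDA, Flux

dim, len, batch_size, num_heads = 128, 8, 128, 32
mha = MultiHeadAttention(dim, num_heads)
x = rand(Float32, dim, len, batch_size)

@btime $mha($x; impl=:tullio)                          # forward pass
@btime gradient(m -> sum(m($x; impl=:tullio)), $mha)   # forward + backward

# GPU variants would move model and data first, e.g. gpu(mha) and gpu(x),
# and wrap the calls in CUDA.@sync.
```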