Skip to content

Commit 1b313f2

Browse files
[ci skip] updates
1 parent 8ac9e6b commit 1b313f2

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
11+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1314
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

src/layers/attention.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
using Flux, Test, LinearAlgebra, Random, Statistics
2-
using CUDA, CUDAKernels, LoopVectorization
1+
using Flux, Functors, Test, LinearAlgebra, Random, Statistics
2+
using CUDA, CUDAKernels, KernelAbstractions, LoopVectorization
33
using Tullio
44
using NeuralAttentionlib
55
using BenchmarkTools
6-
6+
CUDA.allowscalar(false)
77
const A3{T} = AbstractArray{T, 3}
88

99
"""
@@ -144,11 +144,6 @@ function perf(dim, len, batch_size, num_heads)
144144
mha = MultiHeadAttention(dim, num_heads)
145145
x = rand(Float32, (dim, len, batch_size))
146146

147-
y = mha(x, x, x)
148-
@test y isa Array{Float32, 3}
149-
@test size(y) == (dim, len, batch_size)
150-
151-
152147
println("tullio")
153148
@btime $mha($x, v=:tullio);
154149
@btime gradient(m -> sum(m($x, v=:tullio)), $mha);
@@ -172,4 +167,29 @@ function perf(dim, len, batch_size, num_heads)
172167
return nothing
173168
end
174169

175-
perf(64, 100, 32, 8)
170+
function test(dim, len, batch_size, num_heads)
171+
mha = MultiHeadAttention(dim, num_heads)
172+
x = rand(Float32, (dim, len, batch_size))
173+
y = mha(x, v=:tullio)
174+
@test y isa Array{Float32, 3}
175+
@test size(y) == (dim, len, batch_size)
176+
y2 = mha(x, v=:nnalib)
177+
@test size(y) == size(y2)
178+
@test y2 ≈ y
179+
180+
if CUDA.functional()
181+
mha_gpu = mha |> gpu
182+
x_gpu = x |> gpu
183+
184+
y_gpu = mha_gpu(x_gpu, v=:tullio)
185+
y_gpu2 = mha_gpu(x_gpu, v=:nnalib)
186+
@test Array(y_gpu) ≈ Array(y_gpu2)
187+
@test Array(y_gpu) ≈ y
188+
end
189+
return nothing
190+
end
191+
192+
193+
test(12, 3, 2, 4)
194+
195+
perf(64, 100, 32, 4)

0 commit comments

Comments
 (0)