Skip to content

Commit 7ded1e4

Browse files
author
William Moses
committed
WIP: kernels
1 parent 65fcbe8 commit 7ded1e4

File tree

4 files changed

+40
-12
lines changed

4 files changed

+40
-12
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2121
[weakdeps]
2222
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2323
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
24+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2425
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2526
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2627
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2728

2829
[extensions]
2930
ReactantAbstractFFTsExt = "AbstractFFTs"
3031
ReactantArrayInterfaceExt = "ArrayInterface"
32+
ReactantCUDAExt = "CUDA"
3133
ReactantNNlibExt = "NNlib"
3234
ReactantStatisticsExt = "Statistics"
3335
ReactantYaoBlocksExt = "YaoBlocks"
@@ -54,7 +56,8 @@ julia = "1.10"
5456
[extras]
5557
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5658
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
59+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5760
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
5861

59-
[sources]
60-
ReactantCore = { path = "lib/ReactantCore" }
62+
[sources.ReactantCore]
63+
path = "lib/ReactantCore"

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
45
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
56
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
67
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

test/cuda.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Reactant
2+
using Test
3+
using CUDA
4+
5+
function square_kernel!(x)
6+
i = threadIdx().x
7+
x[i] *= x[i]
8+
sync_threads()
9+
return nothing
10+
end
11+
12+
# basic squaring on GPU
13+
function square!(x)
14+
@cuda blocks = 1 threads = length(x) square_kernel!(x)
15+
return nothing
16+
end
17+
18+
@testset "Square Kernel" begin
19+
oA = collect(1:1:64)
20+
A = Reactant.to_rarray(oA)
21+
func = @compile square!(A)
22+
@test all(A .≈ (oA .* oA))
23+
end

test/runtests.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,18 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6060

6161
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
6262
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
63+
@safetestset "CUDA" include("cuda.jl")
6364
@safetestset "AbstractFFTs" include("integration/fft.jl")
6465
end
6566

66-
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
67-
@testset "Neural Networks" begin
68-
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
69-
@safetestset "Flux.jl Integration" include("nn/flux.jl")
70-
if Sys.islinux()
71-
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
72-
@safetestset "Lux Integration" include("nn/lux.jl")
73-
end
74-
end
75-
end
67+
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
68+
# @testset "Neural Networks" begin
69+
# @safetestset "NNlib Primitives" include("nn/nnlib.jl")
70+
# @safetestset "Flux.jl Integration" include("nn/flux.jl")
71+
# if Sys.islinux()
72+
# @safetestset "LuxLib Primitives" include("nn/luxlib.jl")
73+
# @safetestset "Lux Integration" include("nn/lux.jl")
74+
# end
75+
# end
76+
# end
7677
end

0 commit comments

Comments
 (0)