Skip to content

Commit 198846b

Browse files
committed
support multiple devices per backend (#554)
1 parent c3010d6 commit 198846b

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

src/KernelAbstractions.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,34 @@ function priority!(::Backend, prio::Symbol)
586586
return nothing
587587
end
588588

589+
"""
590+
device(::Backend)::Int
591+
592+
Returns the ordinal number of the currently active device starting at one.
593+
"""
594+
function device(::Backend)
595+
return 1
596+
end
597+
598+
"""
599+
ndevices(::Backend)::Int
600+
601+
Returns the number of devices the backend supports.
602+
"""
603+
function ndevices(::Backend)
604+
return 1
605+
end
606+
607+
"""
608+
device!(::Backend, id::Int)
609+
"""
610+
function device!(backend::Backend, id::Int)
611+
if !(0 < id <= ndevices(backend))
612+
throw(ArgumentError("Device id $id out of bounds."))
613+
end
614+
return nothing
615+
end
616+
589617
"""
590618
functional(::Backend)
591619

test/devices.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
function devices_testsuite(Backend)
2+
backend = Backend()
3+
4+
current_device = KernelAbstractions.device(backend)
5+
for i in KernelAbstractions.ndevices(backend)
6+
KernelAbstractions.device!(backend, i)
7+
@test KernelAbstractions.device(backend) == i
8+
end
9+
10+
@test_throws ArgumentError KernelAbstractions.device!(backend, 0)
11+
@test_throws ArgumentError KernelAbstractions.device!(backend, KernelAbstractions.ndevices(backend) + 1)
12+
KernelAbstractions.device!(backend, current_device)
13+
return nothing
14+
end

test/testsuite.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ include("private.jl")
3131
include("unroll.jl")
3232
include("nditeration.jl")
3333
include("copyto.jl")
34+
include("devices.jl")
3435
include("print_test.jl")
3536
include("compiler.jl")
3637
include("reflection.jl")
@@ -67,6 +68,10 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
6768
copyto_testsuite(backend, AT)
6869
end
6970

71+
@conditional_testset "Devices" skip_tests begin
72+
devices_testsuite(backend)
73+
end
74+
7075
@conditional_testset "Printing" skip_tests begin
7176
printing_testsuite(backend)
7277
end

0 commit comments

Comments
 (0)