diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl index a9dc1ef2..086bdbef 100644 --- a/src/KernelAbstractions.jl +++ b/src/KernelAbstractions.jl @@ -586,6 +586,34 @@ function priority!(::Backend, prio::Symbol) return nothing end +""" + device(::Backend)::Int + +Returns the ordinal number of the currently active device starting at one. +""" +function device(::Backend) + return 1 +end + +""" + ndevices(::Backend)::Int + +Returns the number of devices the backend supports. +""" +function ndevices(::Backend) + return 1 +end + +""" + device!(::Backend, id::Int) +""" +function device!(backend::Backend, id::Int) + if !(0 < id <= ndevices(backend)) + throw(ArgumentError("Device id $id out of bounds.")) + end + return nothing +end + """ functional(::Backend) diff --git a/test/devices.jl b/test/devices.jl new file mode 100644 index 00000000..2dde4bb8 --- /dev/null +++ b/test/devices.jl @@ -0,0 +1,14 @@ +function devices_testsuite(Backend) + backend = Backend() + + current_device = KernelAbstractions.device(backend) + for i in KernelAbstractions.ndevices(backend) + KernelAbstractions.device!(backend, i) + @test KernelAbstractions.device(backend) == i + end + + @test_throws ArgumentError KernelAbstractions.device!(backend, 0) + @test_throws ArgumentError KernelAbstractions.device!(backend, KernelAbstractions.ndevices(backend) + 1) + KernelAbstractions.device!(backend, current_device) + return nothing +end diff --git a/test/testsuite.jl b/test/testsuite.jl index a92cf73a..3828e4ef 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -31,6 +31,7 @@ include("private.jl") include("unroll.jl") include("nditeration.jl") include("copyto.jl") +include("devices.jl") include("print_test.jl") include("compiler.jl") include("reflection.jl") @@ -67,6 +68,10 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{ copyto_testsuite(backend, AT) end + @conditional_testset "Devices" skip_tests begin + devices_testsuite(backend) + end + @conditional_testset "Printing" skip_tests begin printing_testsuite(backend) end