Skip to content

Commit 5b22527

Browse files
committed
add get_backend for StaticArrays
1 parent f7a37d0 commit 5b22527

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

ext/StaticArraysExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module StaticArraysExt
2+
3+
import KernelAbstractions: get_backend, CPU
4+
using StaticArrays: SizedArray, MArray, SArray
5+
6+
get_backend(A::SizedArray) = get_backend(A.data)
7+
get_backend(::MArray) = CPU()
8+
# TODO: It makes sense to pass SArray to the GPU backend, so we can't make a determination
9+
# get_backend(::SArray) = CPU()
10+
11+
end

src/KernelAbstractions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,4 +865,7 @@ if !isdefined(Base, :get_extension)
865865
include("../ext/SparseArraysExt.jl")
866866
end
867867

868+
# Currently we have a direct dependency on StaticArrays
869+
include("../ext/StaticArraysExt.jl")
870+
868871
end #module

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ struct NewBackend <: KernelAbstractions.GPU end
7272
@test_throws MethodError kernel()
7373
end
7474

75-
7675
include("extensions/enzyme.jl")
7776
@static if VERSION >= v"1.7.0"
7877
@testset "Enzyme" begin

test/test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using InteractiveUtils
44
using LinearAlgebra
55
using SparseArrays
66
using Adapt
7+
using StaticArrays
78

89
identity(x) = x
910

@@ -95,6 +96,18 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
9596
@test @inferred(KernelAbstractions.get_backend(sparse(A))) isa backendT
9697
end
9798

99+
@conditional_testset "StaticArrays" skip_tests begin
100+
backend = Backend()
101+
backendT = typeof(backend).name.wrapper # To look through CUDABackend{true, false}
102+
@test backend isa backendT
103+
104+
@test KernelAbstractions.get_backend(@MMatrix [1.0]) isa CPU
105+
@test_throws ArgumentError KernelAbstractions.get_backend(@SMatrix [1.0])
106+
107+
A = allocate(backend, Float32, 5, 5)
108+
@test @inferred(KernelAbstractions.get_backend(SizedMatrix{5, 5}(A))) isa backendT
109+
end
110+
98111
@conditional_testset "adapt" skip_tests begin
99112
backend = Backend()
100113
x = allocate(backend, Float32, 5)

0 commit comments

Comments
 (0)