Skip to content

Commit 7409f4a

Browse files
committed
add get_backend for StaticArrays
(cherry picked from commit b8b53da)
1 parent 4fda411 commit 7409f4a

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1616
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
1717
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
1818
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
19-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2019
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2120
pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd"
2221

2322
[weakdeps]
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2424
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2525
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2626
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -29,6 +29,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2929
EnzymeExt = "EnzymeCore"
3030
LinearAlgebraExt = "LinearAlgebra"
3131
SparseArraysExt = "SparseArrays"
32+
StaticArraysExt = "StaticArrays"
3233

3334
[compat]
3435
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
@@ -50,6 +51,7 @@ julia = "1.10"
5051
pocl_jll = "7"
5152

5253
[extras]
54+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5355
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
5456
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5557
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

ext/StaticArraysExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module StaticArraysExt
2+
3+
import KernelAbstractions: get_backend, CPU
4+
using StaticArrays: SizedArray, MArray
5+
6+
get_backend(A::SizedArray) = get_backend(A.data)
7+
get_backend(::MArray) = CPU()
8+
9+
end

src/KernelAbstractions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import PrecompileTools
1212
import Atomix: @atomic, @atomicswap, @atomicreplace
1313

1414
using MacroTools
15-
using StaticArrays
1615
using Adapt
1716

1817
"""

test/test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using InteractiveUtils
44
using LinearAlgebra
55
using SparseArrays
66
using Adapt
7+
using StaticArrays
78

89
identity(x) = x
910

@@ -95,6 +96,18 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
9596
@test @inferred(KernelAbstractions.get_backend(sparse(A))) isa backendT
9697
end
9798

99+
@conditional_testset "StaticArrays" skip_tests begin
100+
backend = Backend()
101+
backendT = typeof(backend).name.wrapper # To look through CUDABackend{true, false}
102+
@test backend isa backendT
103+
104+
@test KernelAbstractions.get_backend(@MMatrix [1.0]) isa CPU
105+
@test_throws ArgumentError KernelAbstractions.get_backend(@SMatrix [1.0])
106+
107+
A = allocate(backend, Float32, 5, 5)
108+
@test @inferred(KernelAbstractions.get_backend(SizedMatrix{5, 5}(A))) isa backendT
109+
end
110+
98111
@conditional_testset "adapt" skip_tests begin
99112
backend = Backend()
100113
x = allocate(backend, Float32, 5)

0 commit comments

Comments
 (0)