Skip to content

Commit d647153

Browse files
committed
1 parent 83fb3d7 commit d647153

File tree

4 files changed

+1551
-1531
lines changed

4 files changed

+1551
-1531
lines changed

Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1616
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1717
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1818
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
19+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2122

2223
[extensions]
2324
StructArraysAdaptExt = "Adapt"
24-
StructArraysGPUArraysCoreExt = "GPUArraysCore"
25+
StructArraysGPUArraysCoreExt = ["GPUArraysCore", "KernelAbstractions"]
2526
StructArraysLinearAlgebraExt = "LinearAlgebra"
2627
StructArraysSparseArraysExt = "SparseArrays"
2728
StructArraysStaticArraysExt = "StaticArrays"
@@ -32,10 +33,11 @@ Aqua = "0.8"
3233
ConstructionBase = "1"
3334
DataAPI = "1"
3435
Documenter = "1"
35-
GPUArraysCore = "0.1.2, 0.2"
36+
GPUArraysCore = "0.2"
3637
InfiniteArrays = "0.13"
37-
JLArrays = "0.1"
38+
JLArrays = "0.2"
3839
LinearAlgebra = "1"
40+
KernelAbstractions = "0.9"
3941
OffsetArrays = "1"
4042
PooledArrays = "1"
4143
SparseArrays = "1"
@@ -54,6 +56,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
5456
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
5557
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
5658
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
59+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
5760
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
5861
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
5962
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -63,4 +66,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
6366
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
6467

6568
[targets]
66-
test = ["Adapt", "Aqua", "Documenter", "GPUArraysCore", "InfiniteArrays", "JLArrays", "LinearAlgebra", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]
69+
test = ["Adapt", "Aqua", "Documenter", "GPUArraysCore", "InfiniteArrays", "JLArrays", "LinearAlgebra", "KernelAbstractions", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]

ext/StructArraysAdaptExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
module StructArraysAdaptExt
22
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
33
using Adapt, StructArrays
4-
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s)
4+
5+
function Adapt.adapt_structure(to, s::StructArray)
6+
@info "AAA"
7+
@show s
8+
replace_storage(adapt(to), s)
9+
end
510
end

ext/StructArraysGPUArraysCoreExt.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@ using StructArrays: map_params, array_types
66
using Base: tail
77

88
import GPUArraysCore
9+
import KernelAbstractions as KA
10+
11+
function KA.get_backend(x::T) where {T<:StructArray}
12+
components = StructArrays.components(x)
13+
array_components = filter(
14+
fn -> getfield(components, fn) isa AbstractArray,
15+
fieldnames(typeof(components)))
16+
backends = map(
17+
fn -> KA.get_backend(getfield(components, fn)),
18+
array_components)
919

10-
# for GPU broadcast
11-
import GPUArraysCore
12-
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
13-
backends = map_params(GPUArraysCore.backend, array_types(T))
1420
backend, others = backends[1], tail(backends)
1521
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
1622
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
1723
return backend
1824
end
25+
1926
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
2027

2128
end # module

0 commit comments

Comments
 (0)