Skip to content

Commit 255e4ce

Browse files
Fix sub_materialize for GPU arrays (#261)
Currently, `sub_materialize` (through `sub_materialize_axes`) falls back to materializing on CPU. This PR generalizes that logic by determining the output destination with `similar`, which helps to support non-Array types like GPU arrays. As a stand-in for other GPU arrays, I test this using [JLArrays.JLArray](https://github.com/JuliaGPU/GPUArrays.jl/tree/master/lib/JLArrays), which is a reference implementation for the GPUArrays.jl interface that runs on CPU. An alternative design would be to define memory layouts for GPU arrays (i.e. #9), which would allow more customizability for GPU array backends, however I think it is helpful to have fallbacks that "just work" if reasonable parts of the Base AbstractArray interface are implemented. I hit this issue because I was testing out [`BlockArrays.BlockedArray`](https://juliaarrays.github.io/BlockArrays.jl/stable/man/blockedarrays) wrapping a GPU array and noticed that calling `A[Block(1, 1)]` to access a block instantiated the block on CPU, this PR fixes that issue. --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent 067e9c4 commit 255e4ce

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.11.2"
4+
version = "1.12.0"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
8-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
98
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1010

1111
[weakdeps]
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -18,6 +18,7 @@ ArrayLayoutsSparseArraysExt = "SparseArrays"
1818
Aqua = "0.8"
1919
FillArrays = "1.2.1"
2020
Infinities = "0.1"
21+
JLArrays = "0.2"
2122
LinearAlgebra = "1"
2223
Quaternions = "0.7"
2324
Random = "1"
@@ -30,11 +31,12 @@ julia = "1.10"
3031
[extras]
3132
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3233
Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647"
34+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3335
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
3436
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3537
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3638
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3739
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3840

3941
[targets]
40-
test = ["Aqua", "Infinities", "Quaternions", "Random", "StableRNGs", "SparseArrays", "Test"]
42+
test = ["Aqua", "Infinities", "JLArrays", "Quaternions", "Random", "StableRNGs", "SparseArrays", "Test"]

src/ArrayLayouts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ include("triangular.jl")
127127
include("factorizations.jl")
128128

129129
# Extend this function if you're only looking to dispatch on the axes
130-
@inline sub_materialize_axes(V, _) = Array(V)
130+
@inline sub_materialize_axes(V, _) = copyto!(similar(V, axes(V)), V)
131131
@inline sub_materialize(_, V, ax) = sub_materialize_axes(V, ax)
132132
@inline sub_materialize(L, V) = sub_materialize(L, V, axes(V))
133133
@inline sub_materialize(V::SubArray) = sub_materialize(MemoryLayout(V), V)

test/test_layouts.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TestLayouts
22

3-
using ArrayLayouts, LinearAlgebra, FillArrays, Test
3+
using ArrayLayouts, LinearAlgebra, FillArrays, JLArrays, Test
44
import ArrayLayouts: MemoryLayout, DenseRowMajor, DenseColumnMajor, StridedLayout,
55
ConjLayout, RowMajor, ColumnMajor, UnitStride,
66
SymmetricLayout, HermitianLayout, UpperTriangularLayout,
@@ -404,6 +404,16 @@ struct FooNumber <: Number end
404404
@test ArrayLayouts.mul((1:11)', F) isa AbstractMatrix{Int}
405405
end
406406

407+
@testset "GPUArrays/JLArrays" begin
408+
A = jl(randn(5,5))
409+
@test MemoryLayout(A) == DenseColumnMajor()
410+
@test ArrayLayouts.layout_getindex(A,1:3,1:3) == A[1:3,1:3]
411+
@test ArrayLayouts.layout_getindex(A,1:3,1:3) isa JLArray{Float64}
412+
V = view(A,1:3,1:3)
413+
@test ArrayLayouts.sub_materialize(V) == A[1:3,1:3]
414+
@test ArrayLayouts.sub_materialize(V) isa JLArray{Float64}
415+
end
416+
407417
@testset "Triangular col/rowsupport" begin
408418
A = randn(5,5)
409419
@test colsupport(UpperTriangular(A),3) Base.OneTo(3)

0 commit comments

Comments
 (0)