diff --git a/Project.toml b/Project.toml index 01d4c0de..054d7df5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BlockArrays" uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -version = "1.4.0" +version = "1.5.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -9,12 +9,15 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" [extensions] +BlockArraysAdaptExt = "Adapt" BlockArraysBandedMatricesExt = "BandedMatrices" [compat] +Adapt = "4.3" Aqua = "0.8" ArrayLayouts = "1.0.8" BandedMatrices = "1.0" @@ -30,6 +33,7 @@ Test = "1" julia = "1.10" [extras] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -42,6 +46,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = [ + "Adapt", "Aqua", "BandedMatrices", "Documenter", diff --git a/ext/BlockArraysAdaptExt.jl b/ext/BlockArraysAdaptExt.jl new file mode 100644 index 00000000..259b23d4 --- /dev/null +++ b/ext/BlockArraysAdaptExt.jl @@ -0,0 +1,14 @@ +module BlockArraysAdaptExt + +using Adapt +using BlockArrays +using BlockArrays: _BlockArray, _BlockedUnitRange +import Adapt: adapt_structure + +adapt_structure(to, r::BlockedUnitRange) = _BlockedUnitRange(adapt(to, r.first), map(adapt(to), r.lasts)) +adapt_structure(to, r::BlockedOneTo) = BlockedOneTo(map(adapt(to), r.lasts)) + +adapt_structure(to, A::BlockArray) = _BlockArray(map(adapt(to), blocks(A)), map(adapt(to), axes(A))) +adapt_structure(to, A::BlockedArray) = BlockedArray(adapt(to, A.blocks), map(adapt(to), axes(A))) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 8a87e547..06d9bebb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,3 +29,4 @@ include("test_blockreduce.jl") include("test_blockdeque.jl") include("test_blockcholesky.jl") include("test_blockbanded.jl") +include("test_adapt.jl") diff --git a/test/test_adapt.jl b/test/test_adapt.jl new file mode 100644 index 00000000..53c61585 --- /dev/null +++ b/test/test_adapt.jl @@ -0,0 +1,38 @@ +module TestBlockArraysAdapt + +using BlockArrays, Adapt, Test + +@testset "Adapt" begin + @testset "Adapt Ranges" begin + @test blockisequal(adapt(Array, blockedrange([2, 3])), blockedrange([2, 3])) + @test blockisequal(adapt(Array, blockedrange(2, [2, 3])), blockedrange(2, [2, 3])) + end + + @testset "Adapt Block Arrays" begin + A = BlockArray(randn(4, 4), [2, 2], [2, 2]) + à = adapt(Array, A) + @test à == A + @test à isa BlockArray{Float64,2} + @test blockisequal(axes(Ã), axes(A)) + V = view(A, :, :) + Ṽ = adapt(Array, V) + @test Ṽ == V + @test Ṽ isa SubArray{Float64,2} + @test parent(Ṽ) isa BlockArray{Float64,2} + @test blockisequal(axes(parent(Ṽ)), axes(A)) + + A = BlockedArray(randn(4, 4), [2, 2], [2, 2]) + à = adapt(Array, A) + @test à == A + @test à isa BlockedArray{Float64,2} + @test blockisequal(axes(Ã), axes(A)) + V = view(A, :, :) + Ṽ = adapt(Array, V) + @test Ṽ == V + @test Ṽ isa SubArray{Float64,2} + @test parent(Ṽ) isa BlockedArray{Float64,2} + @test blockisequal(axes(parent(Ṽ)), axes(A)) + end +end + +end