Skip to content

Commit 520871c

Browse files
committed
Add support for Adapt.jl
1 parent 56b12ca commit 520871c

File tree

6 files changed

+41
-1
lines changed

6 files changed

+41
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ uuid = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
33
version = "0.5.0"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
910

1011
[compat]
12+
Adapt = "1"
1113
Requires = "0.5,1"
1214
UnsafeArrays = "1"
1315
julia = "1"

src/ArraysOfArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ __precompile__(true)
44

55
module ArraysOfArrays
66

7+
using Adapt
78
using Requires
89
using Statistics
910
using UnsafeArrays

src/array_of_similar_arrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,12 @@ UnsafeArrays.unsafe_uview(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} =
201201
ArrayOfSimilarArrays{T,M,N}(uview(A.data))
202202

203203

204+
function Adapt.adapt_structure(to, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
205+
adapted_data = adapt(to, A.data)
206+
ArrayOfSimilarArrays{eltype(adapted_data),M,N}(adapted_data)
207+
end
208+
209+
204210
function innermap(f::Base.Callable, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
205211
new_data = map(f, A.data)
206212
U = eltype(new_data)

src/vector_of_arrays.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,16 @@ function UnsafeArrays.uview(A::VectorOfArrays)
397397
end
398398

399399

400+
function Adapt.adapt_structure(to, A::VectorOfArrays)
401+
VectorOfArrays(
402+
adapt(to, A.data),
403+
adapt(to, A.elem_ptr),
404+
adapt(to, A.kernel_size),
405+
no_consistency_checks
406+
)
407+
end
408+
409+
400410
function innermap(f::Base.Callable, A::VectorOfArrays)
401411
new_data = map(f, A.data)
402412
VectorOfArrays(new_data, A.elem_ptr, A.kernel_size, simple_consistency_checks)

test/array_of_similar_arrays.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Test
55

66
using ElasticArrays
77
using UnsafeArrays
8+
using Adapt
89

910
using Statistics
1011
using StatsBase: cov2cor
@@ -160,6 +161,14 @@ using StatsBase: cov2cor
160161
end
161162

162163

164+
@testset "adapt" begin
165+
A_flat = rand(2,3,4,5,6)
166+
A_nested = nestedview(A_flat, 2)
167+
@test @inferred(adapt(identity, A_nested)) == A_nested
168+
@test typeof(adapt(identity, A_nested)) == typeof(A_nested)
169+
end
170+
171+
163172
@testset "deepcopy" begin
164173
A = ArrayOfSimilarArrays{Float64,1}(rand_flat_array(Val(1)))
165174
@test (@inferred deepcopy(A)) == A
@@ -274,7 +283,6 @@ using StatsBase: cov2cor
274283
end
275284
end
276285

277-
278286
@testset "examples" begin
279287
A_flat = rand(2,3,4,5,6)
280288
A_nested = nestedview(A_flat, 2)
@@ -323,6 +331,7 @@ using StatsBase: cov2cor
323331
@test_throws ArgumentError pop!(A_nested)
324332

325333
end
334+
326335
@testset "misc" begin
327336
N = 4
328337
r1 = rand(1,4); r2 = rand(1,4); r3 = rand(1,4); r4 = rand(1,4)

test/vector_of_arrays.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Statistics
66
using Test
77

88
using UnsafeArrays
9+
using Adapt
910

1011
using ArraysOfArrays: full_consistency_checks, append_elemptr!, element_ptr
1112

@@ -222,6 +223,17 @@ using ArraysOfArrays: full_consistency_checks, append_elemptr!, element_ptr
222223
end
223224

224225

226+
@testset "adapt" begin
227+
A1 = VectorOfArrays(ref_AoA1(Float32, 3))
228+
@test @inferred(adapt(identity, A1)) == A1
229+
@test typeof(adapt(identity, A1)) == typeof(A1)
230+
231+
A3 = VectorOfArrays(ref_AoA3(Float32, 3))
232+
@test @inferred(adapt(identity, A3)) == A3
233+
@test typeof(adapt(identity, A3)) == typeof(A3)
234+
end
235+
236+
225237
@testset "examples" begin
226238
VA = @inferred(VectorOfArrays{Float64, 2}())
227239

0 commit comments

Comments
 (0)