Skip to content

Commit 7fed9cc

Browse files
committed
Use package extensions
1 parent 89f8f89 commit 7fed9cc

8 files changed

+109
-61
lines changed

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
99
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1010

11+
[weakdeps]
12+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
13+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
14+
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
15+
16+
[extensions]
17+
ArraysOfArraysAdaptExt = "Adapt"
18+
ArraysOfArraysChainRulesCoreExt = "ChainRulesCore"
19+
ArraysOfArraysStaticArraysCoreExt = "StaticArraysCore"
20+
1121
[compat]
1222
Adapt = "1, 2, 3, 4"
1323
ChainRulesCore = "1"
@@ -16,11 +26,12 @@ Statistics = "1"
1626
julia = "1.6"
1727

1828
[extras]
29+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1930
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2031
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
2132
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2233
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2334
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2435

2536
[targets]
26-
test = ["ChainRulesTestUtils", "ElasticArrays", "StaticArrays", "StatsBase", "Test"]
37+
test = ["Adapt", "ChainRulesTestUtils", "ElasticArrays", "StaticArrays", "StatsBase", "Test"]

ext/ArraysOfArraysAdaptExt.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
2+
3+
module ArraysOfArraysAdaptExt
4+
5+
import Adapt
6+
using Adapt: adapt
7+
8+
using ArraysOfArrays: ArrayOfSimilarArrays, VectorOfArrays
9+
using ArraysOfArrays: no_consistency_checks
10+
11+
12+
function Adapt.adapt_structure(to, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
13+
adapted_data = adapt(to, A.data)
14+
ArrayOfSimilarArrays{eltype(adapted_data),M,N}(adapted_data)
15+
end
16+
17+
18+
function Adapt.adapt_structure(to, A::VectorOfArrays)
19+
VectorOfArrays(
20+
adapt(to, A.data),
21+
adapt(to, A.elem_ptr),
22+
adapt(to, A.kernel_size),
23+
no_consistency_checks
24+
)
25+
end
26+
27+
28+
end # module ArraysOfArraysAdaptExt
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
2+
3+
module ArraysOfArraysChainRulesCoreExt
4+
5+
import ChainRulesCore
6+
using ChainRulesCore: NoTangent, unthunk
7+
8+
using ArraysOfArrays: ArrayOfSimilarArrays
9+
using ArraysOfArrays: flatview
10+
11+
12+
function _aosa_ctor_fromflat_pullback(ΔΩ)
13+
NoTangent(), flatview(convert(ArrayOfSimilarArrays, unthunk(ΔΩ)))
14+
end
15+
16+
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
17+
return ArrayOfSimilarArrays{T,M,N}(flat_data), _aosa_ctor_fromflat_pullback
18+
end
19+
20+
_aosa_ctor_fromnested_pullback(ΔΩ) = NoTangent(), ΔΩ
21+
22+
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
23+
return ArrayOfSimilarArrays{T,M,N}(A), _aosa_ctor_fromnested_pullback
24+
end
25+
26+
27+
function ChainRulesCore.rrule(::typeof(flatview), A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
28+
function flatview_pullback(ΔΩ)
29+
data = unthunk(ΔΩ)
30+
NoTangent(), ArrayOfSimilarArrays{eltype(data),M,N}(data)
31+
end
32+
33+
return flatview(A), flatview_pullback
34+
end
35+
36+
37+
end # module ArraysOfArraysChainRulesCoreExt
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
2+
3+
module ArraysOfArraysStaticArraysCoreExt
4+
5+
import StaticArraysCore
6+
using StaticArraysCore: StaticArray, SVector
7+
8+
import ArraysOfArrays
9+
using ArraysOfArrays: nestedview
10+
11+
12+
@inline ArraysOfArrays.flatview(A::AbstractArray{SA,N}) where {S,T,M,N,SA<:StaticArray{S,T,M}} =
13+
reshape(reinterpret(T, A), size(SA)..., size(A)...)
14+
15+
16+
@inline function ArraysOfArrays.nestedview(A::AbstractArray{T}, SA::Type{SVector{S,T}}) where {T,S}
17+
size_A = size(A)
18+
size_A[1] == S || throw(DimensionMismatch("Length $S of static vector type does not match first dimension of array of size $size_A"))
19+
reshape(reinterpret(SA, A), ArraysOfArrays._tail(size_A)...)
20+
end
21+
22+
@inline ArraysOfArrays.nestedview(A::AbstractArray{T}, ::Type{SVector{S}}) where {T,S} =
23+
nestedview(A, SVector{S,T})
24+
25+
26+
end # module ArraysOfArraysStaticArraysCoreExt

src/ArraysOfArrays.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ ArraysOfArrays provides two different types of nested arrays:
1313
"""
1414
module ArraysOfArrays
1515

16-
using Adapt
1716
using Statistics
18-
using ChainRulesCore
19-
20-
import StaticArraysCore
2117

2218
include("util.jl")
2319
include("functions.jl")
2420
include("array_of_similar_arrays.jl")
2521
include("vector_of_arrays.jl")
26-
include("arrays_of_static_arrays.jl")
22+
23+
@static if !isdefined(Base, :get_extension)
24+
include("../ext/ArraysOfArraysAdaptExt.jl")
25+
include("../ext/ArraysOfArraysChainRulesCoreExt.jl")
26+
include("../ext/ArraysOfArraysStaticArraysCoreExt.jl")
27+
end
2728

2829
end # module

src/array_of_similar_arrays.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -85,25 +85,11 @@ end
8585

8686
export ArrayOfSimilarArrays
8787

88-
function _aosa_ctor_fromflat_pullback(ΔΩ)
89-
NoTangent(), flatview(convert(ArrayOfSimilarArrays, unthunk(ΔΩ)))
90-
end
91-
92-
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
93-
return ArrayOfSimilarArrays{T,M,N}(flat_data), _aosa_ctor_fromflat_pullback
94-
end
95-
9688
function ArrayOfSimilarArrays{T,M,N}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
9789
B = ArrayOfSimilarArrays{T,M,N}(Array{T}(undef, innersize(A)..., size(A)...))
9890
copyto!(B, A)
9991
end
10092

101-
_aosa_ctor_fromnested_pullback(ΔΩ) = NoTangent(), ΔΩ
102-
103-
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
104-
return ArrayOfSimilarArrays{T,M,N}(A), _aosa_ctor_fromnested_pullback
105-
end
106-
10793
ArrayOfSimilarArrays{T}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U} =
10894
ArrayOfSimilarArrays{T,M,N}(A)
10995

@@ -143,15 +129,6 @@ the result may be freely changed without breaking the inner consistency of
143129
"""
144130
flatview(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = A.data
145131

146-
function ChainRulesCore.rrule(::typeof(flatview), A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
147-
function flatview_pullback(ΔΩ)
148-
data = unthunk(ΔΩ)
149-
NoTangent(), ArrayOfSimilarArrays{eltype(data),M,N}(data)
150-
end
151-
152-
return flatview(A), flatview_pullback
153-
end
154-
155132

156133
Base.size(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = split_tuple(size(A.data), Val{M}())[2]
157134

@@ -216,13 +193,6 @@ end
216193
Base.prepend!(dest::ArrayOfSimilarArrays{T,M,N}, src::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U} =
217194
prepend!(dest, ArrayOfSimilarArrays(src))
218195

219-
220-
function Adapt.adapt_structure(to, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
221-
adapted_data = adapt(to, A.data)
222-
ArrayOfSimilarArrays{eltype(adapted_data),M,N}(adapted_data)
223-
end
224-
225-
226196
function innermap(f::Base.Callable, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
227197
new_data = map(f, A.data)
228198
U = eltype(new_data)

src/arrays_of_static_arrays.jl

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/vector_of_arrays.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -398,16 +398,6 @@ function Base.empty!(A::VectorOfArrays)
398398
end
399399

400400

401-
function Adapt.adapt_structure(to, A::VectorOfArrays)
402-
VectorOfArrays(
403-
adapt(to, A.data),
404-
adapt(to, A.elem_ptr),
405-
adapt(to, A.kernel_size),
406-
no_consistency_checks
407-
)
408-
end
409-
410-
411401
function innermap(f::Base.Callable, A::VectorOfArrays)
412402
new_data = map(f, A.data)
413403
VectorOfArrays(new_data, A.elem_ptr, A.kernel_size, simple_consistency_checks)

0 commit comments

Comments
 (0)