Skip to content

Commit c5a508a

Browse files
committed
Add rrules for ArrayOfSimilarArrays ctors
1 parent dd701b9 commit c5a508a

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

src/array_of_similar_arrays.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,40 @@ struct ArrayOfSimilarArrays{
7070
data::P
7171

7272
function ArrayOfSimilarArrays{T,M,N}(flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
73-
size_inner, size_outer = split_tuple(size(flat_data), Val{M}())
7473
require_ndims(flat_data, _add_vals(Val{M}(), Val{N}()))
7574
conv_parent = _convert_elype(T, flat_data)
7675
P = typeof(conv_parent)
7776
new{T,M,N,L,P}(conv_parent)
7877
end
78+
end
7979

80-
function ArrayOfSimilarArrays{T,M}(flat_data::AbstractArray{U,L}) where {T,M,L,U}
81-
size_inner, size_outer = split_tuple(size(flat_data), Val{M}())
82-
N = length(size_outer)
83-
conv_parent = _convert_elype(T, flat_data)
84-
P = typeof(conv_parent)
85-
new{T,M,N,L,P}(conv_parent)
86-
end
80+
function ArrayOfSimilarArrays{T,M}(flat_data::AbstractArray{U,L}) where {T,M,L,U}
81+
_, size_outer = split_tuple(size(flat_data), Val{M}())
82+
N = length(size_outer)
83+
ArrayOfSimilarArrays{T,M,N}(flat_data)
8784
end
8885

8986
export ArrayOfSimilarArrays
9087

88+
function _aosa_ctor_fromflat_pullback(ΔΩ)
89+
NoTangent(), flatview(convert(ArrayOfSimilarArrays, unthunk(ΔΩ)))
90+
end
91+
92+
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
93+
return ArrayOfSimilarArrays{T,M,N}(flat_data), _aosa_ctor_fromflat_pullback
94+
end
95+
9196
function ArrayOfSimilarArrays{T,M,N}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
9297
B = ArrayOfSimilarArrays{T,M,N}(Array{T}(undef, innersize(A)..., size(A)...))
9398
copyto!(B, A)
9499
end
95100

101+
_aosa_ctor_fromnested_pullback(ΔΩ) = NoTangent(), ΔΩ
102+
103+
function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
104+
return ArrayOfSimilarArrays{T,M,N}(A), _aosa_ctor_fromnested_pullback
105+
end
106+
96107
ArrayOfSimilarArrays{T}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U} =
97108
ArrayOfSimilarArrays{T,M,N}(A)
98109

test/array_of_similar_arrays.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Test
66
using ElasticArrays
77
using UnsafeArrays
88
using Adapt
9+
using ChainRulesTestUtils
910

1011
using Statistics
1112
using StatsBase: cov2cor
@@ -86,6 +87,8 @@ using StatsBase: cov2cor
8687
test_from_flat(VectorOfSimilarArrays{Float32}, VectorOfSimilarArrays{Float32,2,3,Array{Float32,3}}, Val(3))
8788
test_from_flat(VectorOfSimilarVectors{Float32}, VectorOfSimilarVectors{Float32,Array{Float32,2}}, Val(2))
8889
test_from_flat(VectorOfSimilarVectors{Float32}, VectorOfSimilarVectors{Float32,Array{Float32,2}}, Val(2))
90+
91+
test_rrule(ArrayOfSimilarArrays{Float64,2,2}, rand(2,3,4,5))
8992
end
9093

9194

@@ -115,6 +118,8 @@ using StatsBase: cov2cor
115118

116119
r = @inferred(rand(5,5))
117120
@test @inferred(flatview(ArrayOfSimilarVectors(r))) == r
121+
122+
test_rrule(ArrayOfSimilarArrays{Float64,2,2}, [rand(2,3) for i in 1:5, j in 1:6])
118123
end
119124

120125
@testset "add remove" begin

0 commit comments

Comments
 (0)