@@ -70,29 +70,40 @@ struct ArrayOfSimilarArrays{
70
70
data:: P
71
71
72
72
function ArrayOfSimilarArrays {T,M,N} (flat_data:: AbstractArray{U,L} ) where {T,M,N,L,U}
73
- size_inner, size_outer = split_tuple (size (flat_data), Val {M} ())
74
73
require_ndims (flat_data, _add_vals (Val {M} (), Val {N} ()))
75
74
conv_parent = _convert_elype (T, flat_data)
76
75
P = typeof (conv_parent)
77
76
new {T,M,N,L,P} (conv_parent)
78
77
end
78
+ end
79
79
80
- function ArrayOfSimilarArrays {T,M} (flat_data:: AbstractArray{U,L} ) where {T,M,L,U}
81
- size_inner, size_outer = split_tuple (size (flat_data), Val {M} ())
82
- N = length (size_outer)
83
- conv_parent = _convert_elype (T, flat_data)
84
- P = typeof (conv_parent)
85
- new {T,M,N,L,P} (conv_parent)
86
- end
80
+ function ArrayOfSimilarArrays {T,M} (flat_data:: AbstractArray{U,L} ) where {T,M,L,U}
81
+ _, size_outer = split_tuple (size (flat_data), Val {M} ())
82
+ N = length (size_outer)
83
+ ArrayOfSimilarArrays {T,M,N} (flat_data)
87
84
end
88
85
89
86
export ArrayOfSimilarArrays
90
87
88
+ function _aosa_ctor_fromflat_pullback (ΔΩ)
89
+ NoTangent (), flatview (convert (ArrayOfSimilarArrays, unthunk (ΔΩ)))
90
+ end
91
+
92
+ function ChainRulesCore. rrule (:: Type{ArrayOfSimilarArrays{T,M,N}} , flat_data:: AbstractArray{U,L} ) where {T,M,N,L,U}
93
+ return ArrayOfSimilarArrays {T,M,N} (flat_data), _aosa_ctor_fromflat_pullback
94
+ end
95
+
91
96
function ArrayOfSimilarArrays {T,M,N} (A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U}
92
97
B = ArrayOfSimilarArrays {T,M,N} (Array {T} (undef, innersize (A)... , size (A)... ))
93
98
copyto! (B, A)
94
99
end
95
100
101
+ _aosa_ctor_fromnested_pullback (ΔΩ) = NoTangent (), ΔΩ
102
+
103
+ function ChainRulesCore. rrule (:: Type{ArrayOfSimilarArrays{T,M,N}} , A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U}
104
+ return ArrayOfSimilarArrays {T,M,N} (A), _aosa_ctor_fromnested_pullback
105
+ end
106
+
96
107
ArrayOfSimilarArrays {T} (A:: AbstractArray{<:AbstractArray{U,M},N} ) where {T,M,N,U} =
97
108
ArrayOfSimilarArrays {T,M,N} (A)
98
109
0 commit comments