Skip to content

Commit df0e604

Browse files
committed
Brick back FillArrays tests
1 parent 9add5ec commit df0e604

File tree

1 file changed

+52
-40
lines changed

1 file changed

+52
-40
lines changed

src/kroneckerarray.jl

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
3131
)
3232
end
3333
elt = promote_type(eltype(a), eltype(b))
34-
return KroneckerArray(_convert(AbstractArray{elt}, a), _convert(AbstractArray{elt}, b))
34+
return _convert(AbstractArray{elt}, a) _convert(AbstractArray{elt}, b)
3535
end
3636
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
3737
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
@@ -70,63 +70,69 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where
7070
return _convert(A, arg1(a)) _convert(B, arg2(a))
7171
end
7272

73-
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
74-
function _similar(a::AbstractArray, elt::Type, axs::Tuple)
75-
return similar(a, elt, axs)
76-
end
77-
function _similar(a::AbstractArray, ax::Tuple)
78-
return _similar(a, eltype(a), ax)
79-
end
80-
function _similar(a::AbstractArray, elt::Type)
81-
return _similar(a, elt, axes(a))
82-
end
83-
function _similar(a::AbstractArray)
84-
return _similar(a, eltype(a), axes(a))
85-
end
86-
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple)
87-
return similar(arrayt, axs)
88-
end
89-
90-
function Base.similar(
91-
a::AbstractArray,
92-
elt::Type,
93-
axs::Tuple{
94-
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
95-
},
96-
)
97-
return _similar(a, elt, map(arg1, axs)) _similar(a, elt, map(arg2, axs))
98-
end
9973
function Base.similar(
10074
a::KroneckerArray,
10175
elt::Type,
10276
axs::Tuple{
10377
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
10478
},
10579
)
106-
return _similar(arg1(a), elt, map(arg1, axs)) _similar(arg2(a), elt, map(arg2, axs))
80+
return similar(arg1(a), elt, map(arg1, axs)) similar(arg2(a), elt, map(arg2, axs))
81+
end
82+
function Base.similar(a::KroneckerArray, elt::Type)
83+
# TODO: Is this a good definition?
84+
return if isactive(arg1(a)) == isactive(arg2(a))
85+
similar(arg1(a), elt) similar(arg2(a), elt)
86+
elseif isactive(arg1(a))
87+
similar(arg1(a), elt) elt.(arg2(a))
88+
elseif isactive(arg2(a))
89+
elt.(arg1(a)) similar(arg2(a), elt)
90+
end
91+
end
92+
function Base.similar(a::KroneckerArray)
93+
# TODO: Is this a good definition?
94+
return if isactive(arg1(a)) == isactive(arg2(a))
95+
similar(arg1(a)) similar(arg2(a))
96+
elseif isactive(arg1(a))
97+
similar(arg1(a)) arg2(a)
98+
elseif isactive(arg2(a))
99+
arg1(a) similar(arg2(a))
100+
end
107101
end
102+
108103
function Base.similar(
109-
arrayt::Type{<:AbstractArray},
104+
a::AbstractArray,
105+
elt::Type,
110106
axs::Tuple{
111107
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
112108
},
113109
)
114-
return _similar(arrayt, map(arg1, axs)) _similar(arrayt, map(arg2, axs))
110+
return similar(a, elt, map(arg1, axs)) similar(a, elt, map(arg2, axs))
115111
end
112+
116113
function Base.similar(
117114
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
118115
axs::Tuple{
119116
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
120117
},
121118
) where {A,B}
122-
return _similar(A, map(arg1, axs)) _similar(B, map(arg2, axs))
119+
return similar(A, map(arg1, axs)) similar(B, map(arg2, axs))
123120
end
124121
function Base.similar(
125122
::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}}
126123
) where {A,B}
127124
return similar(promote_type(A, B), sz)
128125
end
129126

127+
function Base.similar(
128+
arrayt::Type{<:AbstractArray},
129+
axs::Tuple{
130+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
131+
},
132+
)
133+
return similar(arrayt, map(arg1, axs)) similar(arrayt, map(arg2, axs))
134+
end
135+
130136
function Base.permutedims(a::KroneckerArray, perm)
131137
return permutedims(arg1(a), perm) permutedims(arg2(a), perm)
132138
end
@@ -168,7 +174,17 @@ kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
168174
# Eagerly collect arguments to make more general on GPU.
169175
Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
170176

171-
Base.zero(a::KroneckerArray) = zero(arg1(a)) zero(arg2(a))
177+
function Base.zero(a::KroneckerArray)
178+
return if isactive(arg1(a)) == isactive(arg2(a))
179+
# TODO: Maybe this should zero both arguments?
180+
# This is how `a * false` would behave.
181+
arg1(a) zero(arg2(a))
182+
elseif isactive(arg1(a))
183+
zero(arg1(a)) arg2(a)
184+
elseif isactive(arg2(a))
185+
arg1(a) zero(arg2(a))
186+
end
187+
end
172188

173189
using DerivableInterfaces: DerivableInterfaces, zero!
174190
function DerivableInterfaces.zero!(a::KroneckerArray)
@@ -240,19 +256,15 @@ function Base.to_indices(
240256
return I1 I2
241257
end
242258

243-
# Allow customizing for `FillArrays.Eye`.
244-
_getindex(a::AbstractArray, I...) = a[I...]
245259
function Base.getindex(
246260
a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N}
247261
) where {N}
248262
I′ = to_indices(a, I)
249-
return _getindex(arg1(a), arg1.(I′)...) _getindex(arg2(a), arg2.(I′)...)
263+
return arg1(a)[arg1.(I′)...] arg2(a)[arg2.(I′)...]
250264
end
251265
# Fix ambigiuity error.
252266
Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]
253267

254-
# Allow customizing for `FillArrays.Eye`.
255-
_view(a::AbstractArray, I...) = view(a, I...)
256268
arg1(::Colon) = (:)
257269
arg2(::Colon) = (:)
258270
arg1(::Base.Slice) = (:)
@@ -261,13 +273,13 @@ function Base.view(
261273
a::KroneckerArray{<:Any,N},
262274
I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N},
263275
) where {N}
264-
return _view(arg1(a), arg1.(I)...) _view(arg2(a), arg2.(I)...)
276+
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
265277
end
266278
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
267-
return _view(arg1(a), arg1.(I)...) _view(arg2(a), arg2.(I)...)
279+
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
268280
end
269281
# Fix ambigiuity error.
270-
Base.view(a::KroneckerArray{<:Any,0}) = _view(arg1(a)) * _view(arg2(a))
282+
Base.view(a::KroneckerArray{<:Any,0}) = view(arg1(a)) view(arg2(a))
271283

272284
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
273285
return arg1(a) == arg1(b) && arg2(a) == arg2(b)

0 commit comments

Comments
 (0)