Skip to content

Commit 766b435

Browse files
committed
More similar definitions and tests
1 parent a3ed859 commit 766b435

File tree

3 files changed

+149
-8
lines changed

3 files changed

+149
-8
lines changed

src/KroneckerArrays.jl

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker)
717717
return dest
718718
end
719719
function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye)
720-
map!(f, dest.a, a.a, b.a)
720+
map!(f, dest.a, a.a)
721721
return dest
722722
end
723723
function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye)
@@ -932,30 +932,65 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,
932932

933933
# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
934934
function Base.similar(
935-
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
935+
a::SquareEyeKronecker,
936936
elt::Type,
937+
axs::Tuple{
938+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
939+
},
940+
)
941+
ax_a = map(ax -> ax.product.a, axs)
942+
ax_b = map(ax -> ax.product.b, axs)
943+
eye_ax_a = (only(unique(ax_a)),)
944+
return Eye{elt}(eye_ax_a) similar(a.b, elt, ax_b)
945+
end
946+
function Base.similar(
947+
a::KroneckerSquareEye,
948+
elt::Type,
949+
axs::Tuple{
950+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
951+
},
952+
)
953+
ax_a = map(ax -> ax.product.a, axs)
954+
ax_b = map(ax -> ax.product.b, axs)
955+
eye_ax_b = (only(unique(ax_b)),)
956+
return similar(a.a, elt, ax_a) Eye{elt}(eye_ax_b)
957+
end
958+
function Base.similar(
959+
a::SquareEyeSquareEye,
960+
elt::Type,
961+
axs::Tuple{
962+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
963+
},
964+
)
965+
ax_a = map(ax -> ax.product.a, axs)
966+
ax_b = map(ax -> ax.product.b, axs)
967+
eye_ax_a = (only(unique(ax_a)),)
968+
eye_ax_b = (only(unique(ax_b)),)
969+
return Eye{elt}(eye_ax_a) Eye{elt}(eye_ax_b)
970+
end
971+
972+
function Base.similar(
973+
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
937974
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
938975
) where {A}
939976
ax_a = map(ax -> ax.product.a, axs)
940977
ax_b = map(ax -> ax.product.b, axs)
941978
eye_ax_a = (only(unique(ax_a)),)
942-
return Eye{elt}(eye_ax_a) similar(A, elt, ax_b)
979+
return Eye{eltype(arrayt)}(eye_ax_a) similar(A, ax_b)
943980
end
944981
function Base.similar(
945982
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
946-
elt::Type,
947983
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
948984
) where {A}
949985
ax_a = map(ax -> ax.product.a, axs)
950986
ax_b = map(ax -> ax.product.b, axs)
951987
eye_ax_b = (only(unique(ax_b)),)
952-
return similar(A, elt, ax_a) Eye{elt}(eye_ax_b)
988+
return similar(A, ax_a) Eye{eltype(arrayt)}(eye_ax_b)
953989
end
954990
function Base.similar(
955-
arrayt::Type{<:SquareEyeSquareEye},
956-
elt::Type,
957-
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
991+
arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}}
958992
)
993+
elt = eltype(arrayt)
959994
ax_a = map(ax -> ax.product.a, axs)
960995
ax_b = map(ax -> ax.product.b, axs)
961996
eye_ax_a = (only(unique(ax_a)),)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
34
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
45
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/test_basics.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted
2+
using DerivableInterfaces: zero!
23
using FillArrays: Eye
34
using KroneckerArrays:
45
KroneckerArrays,
@@ -187,6 +188,110 @@ end
187188
@test 2a == (2a.a) Eye(2)
188189
@test a * a == (a.a * a.a) Eye(2)
189190

191+
# similar
192+
a = Eye(2) randn(3, 3)
193+
for a′ in (
194+
similar(a),
195+
similar(a, eltype(a)),
196+
similar(a, axes(a)),
197+
similar(a, eltype(a), axes(a)),
198+
similar(typeof(a), axes(a)),
199+
)
200+
@test size(a′) == (6, 6)
201+
@test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)}
202+
@test a′.a === a.a
203+
end
204+
205+
a = Eye(2) randn(3, 3)
206+
for args in ((Float32,), (Float32, axes(a)))
207+
a′ = similar(a, args...)
208+
@test size(a′) == (6, 6)
209+
@test a′ isa KroneckerArray{Float32,ndims(a)}
210+
@test a′.a === Eye{Float32}(2)
211+
end
212+
213+
a = randn(3, 3) Eye(2)
214+
for a′ in (
215+
similar(a),
216+
similar(a, eltype(a)),
217+
similar(a, axes(a)),
218+
similar(a, eltype(a), axes(a)),
219+
similar(typeof(a), axes(a)),
220+
)
221+
@test size(a′) == (6, 6)
222+
@test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)}
223+
@test a′.b === a.b
224+
end
225+
226+
a = randn(3, 3) Eye(2)
227+
for args in ((Float32,), (Float32, axes(a)))
228+
a′ = similar(a, args...)
229+
@test size(a′) == (6, 6)
230+
@test a′ isa KroneckerArray{Float32,ndims(a)}
231+
@test a′.b === Eye{Float32}(2)
232+
end
233+
234+
a = Eye(3) Eye(2)
235+
for a′ in (
236+
similar(a),
237+
similar(a, eltype(a)),
238+
similar(a, axes(a)),
239+
similar(a, eltype(a), axes(a)),
240+
similar(typeof(a), axes(a)),
241+
)
242+
@test size(a′) == (6, 6)
243+
@test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)}
244+
@test a′.a === a.a
245+
@test a′.b === a.b
246+
end
247+
248+
a = Eye(3) Eye(2)
249+
for args in ((Float32,), (Float32, axes(a)))
250+
a′ = similar(a, args...)
251+
@test size(a′) == (6, 6)
252+
@test a′ isa KroneckerArray{Float32,ndims(a)}
253+
@test a′.a === Eye{Float32}(3)
254+
@test a′.b === Eye{Float32}(2)
255+
end
256+
257+
# DerivableInterfaces.zero!
258+
for a in (Eye(2) randn(3, 3), randn(3, 3) Eye(2))
259+
zero!(a)
260+
@test iszero(a)
261+
end
262+
a = Eye(3) Eye(2)
263+
@test_throws ArgumentError zero!(a)
264+
265+
# map!(+, ...)
266+
for a in (Eye(2) randn(3, 3), randn(3, 3) Eye(2))
267+
a′ = similar(a)
268+
map!(+, a′, a, a)
269+
@test collect(a′) 2 * collect(a)
270+
end
271+
a = Eye(3) Eye(2)
272+
a′ = similar(a)
273+
@test_throws ErrorException map!(+, a′, a, a)
274+
275+
# map!(-, ...)
276+
for a in (Eye(2) randn(3, 3), randn(3, 3) Eye(2))
277+
a′ = similar(a)
278+
map!(-, a′, a, a)
279+
@test norm(collect(a′)) 0
280+
end
281+
a = Eye(3) Eye(2)
282+
a′ = similar(a)
283+
@test_throws ErrorException map!(-, a′, a, a)
284+
285+
# map!(-, b, a)
286+
for a in (Eye(2) randn(3, 3), randn(3, 3) Eye(2))
287+
a′ = similar(a)
288+
map!(-, a′, a)
289+
@test collect(a′) -collect(a)
290+
end
291+
a = Eye(3) Eye(2)
292+
a′ = similar(a)
293+
@test_throws ErrorException map!(-, a′, a)
294+
190295
# Eye ⊗ A
191296
rng = StableRNG(123)
192297
a = Eye(2) randn(rng, 3, 3)

0 commit comments

Comments
 (0)