Skip to content

Commit a3ed859

Browse files
committed
Add more tests
1 parent 92e5428 commit a3ed859

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

src/KroneckerArrays.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302302
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
303303
return KroneckerStyle{N,a,b}()
304304
end
305+
function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N}
306+
return KroneckerStyle{N}(a, b)
307+
end
305308
function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
306309
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
307310
end
@@ -316,8 +319,8 @@ end
316319
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
317320
ax_a = map(ax -> ax.product.a, axes(bc))
318321
ax_b = map(ax -> ax.product.b, axes(bc))
319-
bc_a = Broadcasted(A, ax_a)
320-
bc_b = Broadcasted(B, ax_b)
322+
bc_a = Broadcasted(A, nothing, (), ax_a)
323+
bc_b = Broadcasted(B, nothing, (), ax_b)
321324
a = similar(bc_a, elt)
322325
b = similar(bc_b, elt)
323326
return a b

test/test_basics.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted
12
using FillArrays: Eye
23
using KroneckerArrays:
34
KroneckerArrays,
5+
KroneckerArray,
6+
KroneckerStyle,
47
CartesianProductUnitRange,
58
,
69
×,
@@ -9,7 +12,7 @@ using KroneckerArrays:
912
diagonal,
1013
kron_nd,
1114
unproduct
12-
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr
15+
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr
1316
using StableRNGs: StableRNG
1417
using Test: @test, @test_broken, @test_throws, @testset
1518

@@ -41,8 +44,10 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
4144
a = randn(elt, 2, 2) randn(elt, 3, 3)
4245
b = randn(elt, 2, 2) randn(elt, 3, 3)
4346
c = a.a b.b
47+
@test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
4448
@test similar(typeof(a), (2, 3)) isa Matrix{elt}
4549
@test size(similar(typeof(a), (2, 3))) == (2, 3)
50+
@test isreal(a) == (elt <: Real)
4651
@test a[1 × 1, 1 × 1] == a.a[1, 1] * a.b[1, 1]
4752
@test a[1 × 3, 2 × 1] == a.a[1, 2] * a.b[3, 1]
4853
@test a[1 × (2:3), 2 × 1] == a.a[1, 2] * a.b[2:3, 1]
@@ -60,6 +65,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
6065
@test collect(-a) == -collect(a)
6166
@test collect(3 * a) 3 * collect(a)
6267
@test collect(a * 3) collect(a) * 3
68+
@test collect(a / 3) collect(a) / 3
6369
@test a + a == 2a
6470
@test iszero(a - a)
6571
@test collect(a + c) collect(a) + collect(c)
@@ -68,6 +74,61 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
6874
@test collect(f(a)) f(collect(a))
6975
end
7076
@test tr(a) tr(collect(a))
77+
@test norm(a) norm(collect(a))
78+
79+
# Broadcasting
80+
style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b)))
81+
@test BroadcastStyle(typeof(a)) === style
82+
@test_throws "not supported" sin.(a)
83+
a′ = similar(a)
84+
@test_throws "not supported" a′ .= sin.(a)
85+
a′ = similar(a)
86+
@test_broken a′ .= 2 .* a
87+
bc = broadcasted(+, a, a)
88+
@test bc.style === style
89+
@test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)}
90+
@test_broken copy(bc)
91+
bc = broadcasted(*, 2, a)
92+
@test bc.style === style
93+
@test_broken copy(bc)
94+
95+
# Mapping
96+
@test_throws "not supported" map(sin, a)
97+
@test_broken map(Base.Fix1(*, 2), a)
98+
a′ = similar(a)
99+
@test_throws "not supported" map!(sin, a′, a)
100+
a′ = similar(a)
101+
map!(identity, a′, a)
102+
@test collect(a′) collect(a)
103+
a′ = similar(a)
104+
map!(+, a′, a, a)
105+
@test collect(a′) 2 * collect(a)
106+
a′ = similar(a)
107+
map!(-, a′, a, a)
108+
@test norm(collect(a′)) 0
109+
a′ = similar(a)
110+
map!(Base.Fix1(*, 2), a′, a)
111+
@test collect(a′) 2 * collect(a)
112+
a′ = similar(a)
113+
map!(Base.Fix2(*, 2), a′, a)
114+
@test collect(a′) 2 * collect(a)
115+
a′ = similar(a)
116+
map!(Base.Fix2(/, 2), a′, a)
117+
@test collect(a′) collect(a) / 2
118+
a′ = similar(a)
119+
map!(conj, a′, a)
120+
@test collect(a′) conj(collect(a))
121+
122+
if elt <: Real
123+
@test real(a) == a
124+
else
125+
@test_throws ArgumentError real(a)
126+
end
127+
if elt <: Real
128+
@test iszero(imag(a))
129+
else
130+
@test_throws ArgumentError imag(a)
131+
end
71132

72133
a = randn(elt, 2, 2, 2) randn(elt, 3, 3, 3)
73134
@test collect(a) kron_nd(a.a, a.b)

0 commit comments

Comments
 (0)