Skip to content

Commit b62416f

Browse files
authored
Start BlockSparseArraysExt and testing (#16)
1 parent a6336d1 commit b62416f

File tree

6 files changed

+192
-8
lines changed

6 files changed

+192
-8
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.10"
4+
version = "0.1.11"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
89
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
910
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1314

15+
[weakdeps]
16+
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
17+
18+
[extensions]
19+
KroneckerArraysBlockSparseArraysExt = "BlockSparseArrays"
20+
1421
[compat]
22+
Adapt = "4.3.0"
23+
BlockSparseArrays = "0.7.9"
1524
DerivableInterfaces = "0.5.0"
1625
DiagonalArrays = "0.3.5"
1726
FillArrays = "1.13.0"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module KroneckerArraysBlockSparseArraysExt
2+
3+
using BlockSparseArrays: BlockSparseArrays, blockrange
4+
using KroneckerArrays: CartesianProduct, cartesianrange
5+
6+
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
7+
return blockrange(map(cartesianrange, bs))
8+
end
9+
10+
end

src/KroneckerArrays.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ end
9292
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
9393
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
9494

95+
using Adapt: Adapt, adapt
96+
Adapt.adapt_structure(to, a::KroneckerArray) = adapt(to, a.a) adapt(to, a.b)
97+
9598
function Base.copy(a::KroneckerArray)
9699
return copy(a.a) copy(a.b)
97100
end
@@ -930,6 +933,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
930933
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
931934
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
932935

936+
using Adapt: Adapt, adapt
937+
Adapt.adapt_structure(to, a::SquareEyeKronecker) = a.a adapt(to, a.b)
938+
Adapt.adapt_structure(to, a::KroneckerSquareEye) = adapt(to, a.a) a.b
939+
Adapt.adapt_structure(to, a::SquareEyeSquareEye) = adapt(to, a.a) a.b
940+
933941
# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
934942
function Base.similar(
935943
a::SquareEyeKronecker,
@@ -970,22 +978,22 @@ function Base.similar(
970978
end
971979

972980
function Base.similar(
973-
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
981+
arrayt::Type{<:SquareEyeKronecker{T,A,B}},
974982
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
975-
) where {A}
983+
) where {T,A<:SquareEye{T},B}
976984
ax_a = map(ax -> ax.product.a, axs)
977985
ax_b = map(ax -> ax.product.b, axs)
978986
eye_ax_a = (only(unique(ax_a)),)
979-
return Eye{eltype(arrayt)}(eye_ax_a) similar(A, ax_b)
987+
return Eye{T}(eye_ax_a) similar(B, ax_b)
980988
end
981989
function Base.similar(
982-
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
990+
arrayt::Type{<:KroneckerSquareEye{T,A,B}},
983991
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
984-
) where {A}
992+
) where {T,A,B<:SquareEye{T}}
985993
ax_a = map(ax -> ax.product.a, axs)
986994
ax_b = map(ax -> ax.product.b, axs)
987995
eye_ax_b = (only(unique(ax_b)),)
988-
return similar(A, ax_a) Eye{eltype(arrayt)}(eye_ax_b)
996+
return similar(A, ax_a) Eye{T}(eye_ax_b)
989997
end
990998
function Base.similar(
991999
arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}}

test/Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
5+
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
36
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
47
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
8+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
59
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
610
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
711
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
@@ -12,9 +16,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1216
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1317

1418
[compat]
19+
Adapt = "4"
1520
Aqua = "0.8"
21+
BlockArrays = "1.6"
22+
BlockSparseArrays = "0.7"
1623
DerivableInterfaces = "0.5"
1724
FillArrays = "1"
25+
JLArrays = "0.2"
1826
KroneckerArrays = "0.1"
1927
LinearAlgebra = "1.10"
2028
MatrixAlgebraKit = "0.2"

test/test_basics.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
using Adapt: adapt
12
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted
23
using DerivableInterfaces: zero!
34
using FillArrays: Eye
5+
using JLArrays: JLArray
46
using KroneckerArrays:
57
KroneckerArrays,
68
KroneckerArray,
@@ -17,7 +19,7 @@ using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd,
1719
using StableRNGs: StableRNG
1820
using Test: @test, @test_broken, @test_throws, @testset
1921

20-
const elts = (Float32, Float64, ComplexF32, ComplexF64)
22+
elts = (Float32, Float64, ComplexF32, ComplexF64)
2123
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
2224
p = [1, 2] × [3, 4, 5]
2325
@test length(p) == 6
@@ -78,6 +80,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
7880
@test norm(a) norm(collect(a))
7981

8082
# Broadcasting
83+
a = randn(elt, 2, 2) randn(elt, 3, 3)
8184
style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b)))
8285
@test BroadcastStyle(typeof(a)) === style
8386
@test_throws "not supported" sin.(a)
@@ -94,6 +97,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
9497
@test_broken copy(bc)
9598

9699
# Mapping
100+
a = randn(elt, 2, 2) randn(elt, 3, 3)
97101
@test_throws "not supported" map(sin, a)
98102
@test_broken map(Base.Fix1(*, 2), a)
99103
a′ = similar(a)
@@ -120,6 +124,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
120124
map!(conj, a′, a)
121125
@test collect(a′) conj(collect(a))
122126

127+
a = randn(elt, 2, 2) randn(elt, 3, 3)
123128
if elt <: Real
124129
@test real(a) == a
125130
else
@@ -131,6 +136,15 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
131136
@test_throws ArgumentError imag(a)
132137
end
133138

139+
# Adapt
140+
a = randn(elt, 2, 2) randn(elt, 3, 3)
141+
a′ = adapt(JLArray, a)
142+
@test a′ isa KroneckerArray{elt,2,JLArray{elt,2},JLArray{elt,2}}
143+
@test a′.a isa JLArray{elt,2}
144+
@test a′.b isa JLArray{elt,2}
145+
@test Array(a′.a) == a.a
146+
@test Array(a′.b) == a.b
147+
134148
a = randn(elt, 2, 2, 2) randn(elt, 3, 3, 3)
135149
@test collect(a) kron_nd(a.a, a.b)
136150
@test a[1 × 1, 1 × 1, 1 × 1] == a.a[1, 1, 1] * a.b[1, 1, 1]

test/test_blocksparsearrays.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
using Adapt: adapt
2+
using BlockArrays: Block, BlockRange
3+
using BlockSparseArrays:
4+
BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype
5+
using FillArrays: Eye, SquareEye
6+
using JLArrays: JLArray
7+
using KroneckerArrays: KroneckerArray, , ×
8+
using LinearAlgebra: norm
9+
using MatrixAlgebraKit: svd_compact
10+
using Test: @test, @test_broken, @testset
11+
using TestExtras: @constinferred
12+
13+
elts = (Float32, Float64, ComplexF32)
14+
arrayts = (Array, JLArray)
15+
@testset "BlockSparseArraysExt, KroneckerArray blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
16+
arrayts,
17+
elt in elts
18+
19+
dev = adapt(arrayt)
20+
r = blockrange([2 × 2, 3 × 3])
21+
d = Dict(
22+
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
23+
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
24+
)
25+
a = dev(blocksparse(d, r, r))
26+
@test_broken sprint(show, a)
27+
@test sprint(show, MIME("text/plain"), a) isa String
28+
@test blocktype(a) === valtype(d)
29+
@test a isa BlockSparseMatrix{elt,valtype(d)}
30+
@test a[Block(1, 1)] == dev(d[Block(1, 1)])
31+
@test a[Block(1, 1)] isa valtype(d)
32+
@test a[Block(2, 2)] == dev(d[Block(2, 2)])
33+
@test a[Block(2, 2)] isa valtype(d)
34+
@test iszero(a[Block(2, 1)])
35+
@test a[Block(2, 1)] == dev(zeros(elt, 3, 2) zeros(elt, 3, 2))
36+
@test a[Block(2, 1)] isa valtype(d)
37+
@test iszero(a[Block(1, 2)])
38+
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) zeros(elt, 2, 3))
39+
@test a[Block(1, 2)] isa valtype(d)
40+
41+
b = a * a
42+
@test typeof(b) === typeof(a)
43+
@test Array(b) Array(a) * Array(a)
44+
45+
b = a + a
46+
@test typeof(b) === typeof(a)
47+
@test Array(b) Array(a) + Array(a)
48+
49+
b = 3a
50+
@test typeof(b) === typeof(a)
51+
@test Array(b) 3Array(a)
52+
53+
b = a / 3
54+
@test typeof(b) === typeof(a)
55+
@test Array(b) Array(a) / 3
56+
57+
@test norm(a) norm(Array(a))
58+
59+
if arrayt == Array
60+
@test Array(inv(a)) inv(Array(a))
61+
else
62+
# Broken for JLArray, it seems like `inv` isn't
63+
# type stable.
64+
@test_broken inv(a)
65+
end
66+
67+
# Broken operations
68+
@test_broken exp(a)
69+
@test_broken svd_compact(a)
70+
@test_broken a[Block.(1:2), Block(2)]
71+
end
72+
73+
@testset "BlockSparseArraysExt, SquareEyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
74+
arrayts,
75+
elt in elts
76+
77+
if arrayt == JLArray
78+
# TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests
79+
# are broken, look into fixing that.
80+
continue
81+
end
82+
83+
dev = adapt(arrayt)
84+
r = blockrange([2 × 2, 3 × 3])
85+
d = Dict(
86+
Block(1, 1) => Eye{elt}(2, 2) randn(elt, 2, 2),
87+
Block(2, 2) => Eye{elt}(3, 3) randn(elt, 3, 3),
88+
)
89+
a = dev(blocksparse(d, r, r))
90+
@test_broken sprint(show, a)
91+
@test sprint(show, MIME("text/plain"), a) isa String
92+
@test_broken blocktype(a) === valtype(d)
93+
@test_broken a isa BlockSparseMatrix{elt,valtype(d)}
94+
@test a[Block(1, 1)] == dev(d[Block(1, 1)])
95+
@test_broken a[Block(1, 1)] isa valtype(d)
96+
@test a[Block(2, 2)] == dev(d[Block(2, 2)])
97+
@test_broken a[Block(2, 2)] isa valtype(d)
98+
@test iszero(a[Block(2, 1)])
99+
@test a[Block(2, 1)] == dev(zeros(elt, 3, 2) zeros(elt, 3, 2))
100+
@test_broken a[Block(2, 1)] isa valtype(d)
101+
@test iszero(a[Block(1, 2)])
102+
@test a[Block(1, 2)] == dev(zeros(elt, 2, 3) zeros(elt, 2, 3))
103+
@test_broken a[Block(1, 2)] isa valtype(d)
104+
105+
b = a * a
106+
@test typeof(b) === typeof(a)
107+
@test Array(b) Array(a) * Array(a)
108+
109+
b = a + a
110+
@test typeof(b) === typeof(a)
111+
@test Array(b) Array(a) + Array(a)
112+
113+
b = 3a
114+
@test typeof(b) === typeof(a)
115+
@test Array(b) 3Array(a)
116+
117+
b = a / 3
118+
@test typeof(b) === typeof(a)
119+
@test Array(b) Array(a) / 3
120+
121+
@test norm(a) norm(Array(a))
122+
123+
if arrayt == Array
124+
@test Array(inv(a)) inv(Array(a))
125+
else
126+
# Broken for JLArray, it seems like `inv` isn't
127+
# type stable.
128+
@test_broken inv(a)
129+
end
130+
131+
# Broken operations
132+
# @test_broken exp(a)
133+
@test_broken svd_compact(a)
134+
@test_broken a[Block.(1:2), Block(2)]
135+
end

0 commit comments

Comments
 (0)