Skip to content

Commit ca8be69

Browse files
authored
Block sparse SVD (#19)
1 parent 77f2331 commit ca8be69

File tree

6 files changed

+141
-21
lines changed

6 files changed

+141
-21
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.13"
4+
version = "0.1.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,14 +13,16 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1414

1515
[weakdeps]
16+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1617
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
1718

1819
[extensions]
19-
KroneckerArraysBlockSparseArraysExt = "BlockSparseArrays"
20+
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2021

2122
[compat]
2223
Adapt = "4.3.0"
23-
BlockSparseArrays = "0.7.9"
24+
BlockArrays = "1.6"
25+
BlockSparseArrays = "0.7.13"
2426
DerivableInterfaces = "0.5.0"
2527
DiagonalArrays = "0.3.5"
2628
FillArrays = "1.13.0"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,84 @@ module KroneckerArraysBlockSparseArraysExt
22

33
using BlockSparseArrays: BlockSparseArrays, blockrange
44
using KroneckerArrays: CartesianProduct, cartesianrange
5-
65
function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct})
76
return blockrange(map(cartesianrange, bs))
87
end
98

9+
using BlockArrays: AbstractBlockedUnitRange
10+
using BlockSparseArrays: Block, GetUnstoredBlock, eachblockaxis, mortar_axis
11+
using DerivableInterfaces: zero!
12+
using FillArrays: Eye
13+
using KroneckerArrays:
14+
KroneckerArrays,
15+
EyeEye,
16+
EyeKronecker,
17+
KroneckerEye,
18+
KroneckerMatrix,
19+
,
20+
arg1,
21+
arg2,
22+
_similar
23+
24+
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
25+
return mortar_axis(arg2.(eachblockaxis(r)))
26+
end
27+
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
28+
return mortar_axis(arg2.(eachblockaxis(r)))
29+
end
30+
31+
function block_axes(
32+
ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Vararg{Block{1},N}
33+
) where {N}
34+
return ntuple(N) do d
35+
return only(axes(ax[d][I[d]]))
36+
end
37+
end
38+
function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) where {N}
39+
return block_axes(ax, Tuple(I)...)
40+
end
41+
42+
function (f::GetUnstoredBlock)(
43+
::Type{<:AbstractMatrix{KroneckerMatrix{T,A,B}}}, I::Vararg{Int,2}
44+
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
45+
ax_a = arg1.(f.axes)
46+
f_a = GetUnstoredBlock(ax_a)
47+
a = f_a(AbstractMatrix{A}, I...)
48+
49+
ax_b = arg2.(f.axes)
50+
f_b = GetUnstoredBlock(ax_b)
51+
b = f_b(AbstractMatrix{B}, I...)
52+
53+
return a b
54+
end
55+
function (f::GetUnstoredBlock)(
56+
::Type{<:AbstractMatrix{EyeKronecker{T,A,B}}}, I::Vararg{Int,2}
57+
) where {T,A<:Eye{T},B<:AbstractMatrix{T}}
58+
block_ax_a = arg1.(block_axes(f.axes, Block(I)))
59+
a = _similar(A, block_ax_a)
60+
61+
ax_b = arg2.(f.axes)
62+
f_b = GetUnstoredBlock(ax_b)
63+
b = f_b(AbstractMatrix{B}, I...)
64+
65+
return a b
66+
end
67+
function (f::GetUnstoredBlock)(
68+
::Type{<:AbstractMatrix{KroneckerEye{T,A,B}}}, I::Vararg{Int,2}
69+
) where {T,A<:AbstractMatrix{T},B<:Eye{T}}
70+
ax_a = arg1.(f.axes)
71+
f_a = GetUnstoredBlock(ax_a)
72+
a = f_a(AbstractMatrix{A}, I...)
73+
74+
block_ax_b = arg2.(block_axes(f.axes, Block(I)))
75+
b = _similar(B, block_ax_b)
76+
77+
return a b
78+
end
79+
function (f::GetUnstoredBlock)(
80+
::Type{<:AbstractMatrix{EyeEye{T,A,B}}}, I::Vararg{Int,2}
81+
) where {T,A<:Eye{T},B<:Eye{T}}
82+
return error("Not implemented.")
83+
end
84+
1085
end

src/cartesianproduct.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ end
55
arguments(a::CartesianProduct) = (a.a, a.b)
66
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
77

8+
arg1(a::CartesianProduct) = a.a
9+
arg2(a::CartesianProduct) = a.b
10+
811
function Base.show(io::IO, a::CartesianProduct)
912
print(io, a.a, " × ", a.b)
1013
return nothing
@@ -32,6 +35,9 @@ Base.last(r::CartesianProductUnitRange) = last(r.range)
3235
cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product)
3336
unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
3437

38+
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
39+
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))
40+
3541
function CartesianProductUnitRange(p::CartesianProduct)
3642
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
3743
end

src/fillarrays/kroneckerarray.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using FillArrays: RectDiagonal, OnesVector
2+
const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes}
3+
14
using FillArrays: Eye
25
const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
36
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
@@ -11,6 +14,14 @@ const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,
1114
# Like `adapt` but preserves `Eye`.
1215
_adapt(to, a::Eye) = a
1316

17+
# Allows customizing for `FillArrays.Eye`.
18+
function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T}
19+
_convert(AbstractMatrix{T}, a)
20+
end
21+
function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
22+
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
23+
end
24+
1425
# Like `similar` but preserves `Eye`.
1526
function _similar(a::AbstractArray, elt::Type, ax::Tuple)
1627
return similar(a, elt, ax)
@@ -124,15 +135,15 @@ for op in (:+, :-)
124135
end
125136
end
126137

127-
function Base.map!(f::typeof(identity), dest::EyeKronecker, a::EyeKronecker)
138+
function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker)
128139
map!(f, dest.b, src.b)
129140
return dest
130141
end
131-
function Base.map!(f::typeof(identity), dest::KroneckerEye, a::KroneckerEye)
142+
function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye)
132143
map!(f, dest.a, src.a)
133144
return dest
134145
end
135-
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
146+
function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye)
136147
return error("Can't write in-place.")
137148
end
138149
for f in [:+, :-]

src/kroneckerarray.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Allows customizing for `FillArrays.Eye`.
2+
function _convert(A::Type{<:AbstractArray}, a::AbstractArray)
3+
return convert(A, a)
4+
end
5+
16
struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
27
a::A
38
b::B
@@ -9,11 +14,14 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
914
)
1015
end
1116
elt = promote_type(eltype(a), eltype(b))
12-
return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b))
17+
return KroneckerArray(_convert(AbstractArray{elt}, a), _convert(AbstractArray{elt}, b))
1318
end
1419
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
1520
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
1621

22+
arg1(a::KroneckerArray) = a.a
23+
arg2(a::KroneckerArray) = a.b
24+
1725
using Adapt: Adapt, adapt
1826
_adapt(to, a::AbstractArray) = adapt(to, a)
1927
Adapt.adapt_structure(to, a::KroneckerArray) = _adapt(to, a.a) _adapt(to, a.b)
@@ -106,7 +114,8 @@ end
106114
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
107115
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
108116

109-
Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b)
117+
# Eagerly collect arguments to make more general on GPU.
118+
Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b))
110119

111120
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
112121
return convert(Array{T,N}, collect(a))

test/test_blocksparsearrays.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,29 @@ arrayts = (Array, JLArray)
6464
@test_broken inv(a)
6565
end
6666

67+
if (VERSION v"1.11-" && arrayt === Array && elt <: Complex) ||
68+
(arrayt === Array && elt <: Real)
69+
u, s, v = svd_compact(a)
70+
@test Array(u * s * v) Array(a)
71+
else
72+
# Broken on GPU and for complex, investigate.
73+
@test_broken svd_compact(a)
74+
end
75+
6776
# Broken operations
6877
@test_broken exp(a)
69-
@test_broken svd_compact(a)
7078
@test_broken a[Block.(1:2), Block(2)]
7179
end
7280

7381
@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
7482
arrayts,
7583
elt in elts
7684

77-
if arrayt == JLArray
78-
# TODO: Collecting to `Array` is broken for GPU arrays so a lot of tests
79-
# are broken, look into fixing that.
80-
continue
81-
end
82-
8385
dev = adapt(arrayt)
8486
r = @constinferred blockrange([2 × 2, 3 × 3])
8587
d = Dict(
86-
Block(1, 1) => Eye{elt}(2, 2) randn(elt, 2, 2),
87-
Block(2, 2) => Eye{elt}(3, 3) randn(elt, 3, 3),
88+
Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)),
89+
Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)),
8890
)
8991
a = @constinferred dev(blocksparse(d, r, r))
9092
@test sprint(show, a) == sprint(show, Array(a))
@@ -126,11 +128,26 @@ end
126128

127129
@test @constinferred(norm(a)) norm(Array(a))
128130

129-
b = @constinferred exp(a)
130-
@test Array(b) exp(Array(a))
131+
if arrayt === Array
132+
b = @constinferred exp(a)
133+
@test Array(b) exp(Array(a))
134+
else
135+
@test_broken exp(a)
136+
end
137+
138+
if VERSION < v"1.11-" && elt <: Complex
139+
# Broken because of type stability issue in Julia v1.10.
140+
@test_broken svd_compact(a)
141+
elseif arrayt === Array
142+
u, s, v = svd_compact(a)
143+
@test u * s * v a
144+
@test blocktype(u) === blocktype(a)
145+
@test blocktype(v) === blocktype(a)
146+
else
147+
@test_broken svd_compact(a)
148+
end
131149

132150
# Broken operations
133151
@test_broken inv(a)
134-
@test_broken svd_compact(a)
135152
@test_broken a[Block.(1:2), Block(2)]
136153
end

0 commit comments

Comments
 (0)