Skip to content

Commit b5c121a

Browse files
committed
More functionality
1 parent 241c9c3 commit b5c121a

File tree

4 files changed

+253
-7
lines changed

4 files changed

+253
-7
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
7+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
811

912
[compat]
13+
DerivableInterfaces = "0.4.5"
14+
GPUArraysCore = "0.2.0"
1015
LinearAlgebra = "1.10"
16+
MatrixAlgebraKit = "0.2.0"
1117
julia = "1.10"

src/KroneckerArrays.jl

Lines changed: 198 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module KroneckerArrays
22

3+
using GPUArraysCore: GPUArraysCore
4+
35
export , ×
46

57
struct CartesianProduct{A,B}
@@ -28,6 +30,26 @@ end
2830
Base.first(r::CartesianProductUnitRange) = first(r.range)
2931
Base.last(r::CartesianProductUnitRange) = last(r.range)
3032

33+
function Base.axes(r::CartesianProductUnitRange)
34+
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
35+
end
36+
37+
using Base.Broadcast: DefaultArrayStyle
38+
for f in (:+, :-)
39+
@eval begin
40+
function Broadcast.broadcasted(
41+
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
42+
)
43+
return CartesianProductUnitRange(r.product, $f.(r.range, x))
44+
end
45+
function Broadcast.broadcasted(
46+
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
47+
)
48+
return CartesianProductUnitRange(r.product, $f.(x, r.range))
49+
end
50+
end
51+
end
52+
3153
struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
3254
a::A
3355
b::B
@@ -44,6 +66,15 @@ end
4466
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
4567
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
4668

69+
function Base.copy(a::KroneckerArray)
70+
return copy(a.a) copy(a.b)
71+
end
72+
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
73+
copyto!(dest.a, src.a)
74+
copyto!(dest.b, src.b)
75+
return dest
76+
end
77+
4778
function Base.similar(
4879
a::AbstractArray,
4980
elt::Type,
@@ -73,9 +104,21 @@ function Base.similar(
73104
return similar(arrayt, map(ax -> ax.product.a, axs))
74105
similar(arrayt, map(ax -> ax.product.b, axs))
75106
end
107+
function Base.similar(
108+
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
109+
axs::Tuple{
110+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
111+
},
112+
) where {A,B}
113+
return similar(A, map(ax -> ax.product.a, axs)) similar(B, map(ax -> ax.product.b, axs))
114+
end
76115

77116
Base.collect(a::KroneckerArray) = kron(a.a, a.b)
78117

118+
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
119+
return convert(Array{T,N}, collect(a))
120+
end
121+
79122
Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a))
80123

81124
function Base.axes(a::KroneckerArray)
@@ -107,12 +150,23 @@ end
107150
(a::Number, b::AbstractVecOrMat) = a * b
108151
(a::AbstractVecOrMat, b::Number) = a * b
109152

110-
function Base.getindex(::KroneckerArray, ::Int)
111-
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))
153+
function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
154+
GPUArraysCore.assertscalar("getindex")
155+
# Code logic from Kronecker.jl:
156+
# https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
157+
k, l = size(a.b)
158+
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
112159
end
113-
function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N}
114-
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))
160+
function Base.getindex(a::KroneckerMatrix, i::Integer)
161+
return a[CartesianIndices(a)[i]]
115162
end
163+
164+
function Base.getindex(a::KroneckerVector, i::Integer)
165+
GPUArraysCore.assertscalar("getindex")
166+
k = length(a.b)
167+
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
168+
end
169+
116170
function Base.getindex(a::KroneckerVector, i::CartesianProduct)
117171
return a.a[i.a] a.b[i.b]
118172
end
@@ -169,9 +223,18 @@ end
169223
function Base.:*(a::KroneckerArray, b::KroneckerArray)
170224
return (a.a * b.a) (a.b * b.b)
171225
end
172-
function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
226+
function LinearAlgebra.mul!(
227+
c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number
228+
)
229+
iszero(β) ||
230+
iszero(c) ||
231+
throw(
232+
ArgumentError(
233+
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
234+
),
235+
)
173236
mul!(c.a, a.a, b.a)
174-
mul!(c.b, a.b, b.b)
237+
mul!(c.b, a.b, b.b, α, β)
175238
return c
176239
end
177240
function LinearAlgebra.tr(a::KroneckerArray)
@@ -269,4 +332,133 @@ for op in (:+, :-)
269332
end
270333
end
271334

335+
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
336+
dest.a .= a.a
337+
dest.b .= a.b
338+
return dest
339+
end
340+
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
341+
if a.b == b.b
342+
map!(+, dest.a, a.a, b.a)
343+
dest.b .= a.b
344+
elseif a.a == b.a
345+
dest.a .= a.a
346+
map!(+, dest.b, a.b, b.b)
347+
else
348+
throw(
349+
ArgumentError(
350+
"KroneckerArray addition is only supported when the first or second arguments match.",
351+
),
352+
)
353+
end
354+
return dest
355+
end
356+
function Base.map!(
357+
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
358+
)
359+
dest.a .= f.x .* a.a
360+
dest.b .= a.b
361+
return dest
362+
end
363+
function Base.map!(
364+
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
365+
)
366+
dest.a .= a.a
367+
dest.b .= a.b .* f.x
368+
return dest
369+
end
370+
371+
using DerivableInterfaces: DerivableInterfaces, zero!
372+
function DerivableInterfaces.zero!(a::KroneckerArray)
373+
zero!(a.a)
374+
zero!(a.b)
375+
return a
376+
end
377+
378+
using MatrixAlgebraKit:
379+
MatrixAlgebraKit,
380+
AbstractAlgorithm,
381+
TruncationStrategy,
382+
default_eig_algorithm,
383+
default_eigh_algorithm,
384+
eig_full!,
385+
eig_trunc!,
386+
eig_vals!,
387+
eigh_full!,
388+
eigh_trunc!,
389+
eigh_vals!,
390+
initialize_output,
391+
truncate!
392+
393+
struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
394+
a::A
395+
b::B
396+
end
397+
398+
function MatrixAlgebraKit.default_eig_algorithm(a::KroneckerMatrix)
399+
return KroneckerAlgorithm(default_eig_algorithm(a.a), default_eig_algorithm(a.b))
400+
end
401+
function MatrixAlgebraKit.initialize_output(
402+
f::typeof(eig_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm
403+
)
404+
return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b)
405+
end
406+
function MatrixAlgebraKit.eig_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
407+
eig_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a)
408+
eig_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b)
409+
return F
410+
end
411+
412+
function MatrixAlgebraKit.truncate!(
413+
::typeof(eig_trunc!),
414+
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
415+
strategy::TruncationStrategy,
416+
)
417+
return throw(MethodError(truncate!, (eig_trunc!, (D, V), strategy)))
418+
end
419+
420+
function MatrixAlgebraKit.initialize_output(
421+
f::typeof(eig_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm
422+
)
423+
return initialize_output(f, a.a, alg.a) initialize_output(f, a.b, alg.b)
424+
end
425+
function MatrixAlgebraKit.eig_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
426+
eig_vals!(a.a, F.a, alg.a)
427+
eig_vals!(a.b, F.b, alg.b)
428+
return F
429+
end
430+
431+
function MatrixAlgebraKit.default_eigh_algorithm(a::KroneckerMatrix)
432+
return KroneckerAlgorithm(default_eigh_algorithm(a.a), default_eigh_algorithm(a.b))
433+
end
434+
function MatrixAlgebraKit.initialize_output(
435+
f::typeof(eigh_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm
436+
)
437+
return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b)
438+
end
439+
function MatrixAlgebraKit.eigh_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
440+
eigh_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a)
441+
eigh_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b)
442+
return F
443+
end
444+
445+
function MatrixAlgebraKit.truncate!(
446+
::typeof(eigh_trunc!),
447+
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
448+
strategy::TruncationStrategy,
449+
)
450+
return throw(MethodError(truncate!, (eigh_trunc!, (D, V), strategy)))
451+
end
452+
453+
function MatrixAlgebraKit.initialize_output(
454+
f::typeof(eigh_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm
455+
)
456+
return initialize_output(f, a.a, alg.a) initialize_output(f, a.b, alg.b)
457+
end
458+
function MatrixAlgebraKit.eigh_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
459+
eigh_vals!(a.a, F.a, alg.a)
460+
eigh_vals!(a.b, F.b, alg.b)
461+
return F
462+
end
463+
272464
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
44
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
67
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/test_matrixalgebrakit.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using KroneckerArrays:
2+
using LinearAlgebra: Hermitian, diag
3+
using MatrixAlgebraKit:
4+
eig_full,
5+
eig_trunc,
6+
eig_vals,
7+
eigh_full,
8+
eigh_trunc,
9+
eigh_vals,
10+
left_null,
11+
left_orth,
12+
left_polar,
13+
lq_compact,
14+
lq_full,
15+
qr_compact,
16+
qr_full,
17+
right_null,
18+
right_orth,
19+
right_polar,
20+
svd_compact,
21+
svd_full,
22+
svd_trunc,
23+
svd_vals
24+
using Test: @test, @test_throws, @testset
25+
26+
@testset "MatrixAlgebraKit" begin
27+
x = randn(2, 2)
28+
y = randn(3, 3)
29+
a = x y
30+
ah = Hermitian(x) Hermitian(y)
31+
32+
d, v = eig_full(a)
33+
@test a * v v * d
34+
35+
@test_throws MethodError eig_trunc(a)
36+
37+
d = eig_vals(a)
38+
@test d diag(eig_full(a)[1])
39+
40+
d, v = eigh_full(ah)
41+
@test ah * v v * d
42+
43+
@test_throws MethodError eigh_trunc(ah)
44+
45+
d = eigh_vals(ah)
46+
@test d diag(eigh_full(ah)[1])
47+
end

0 commit comments

Comments
 (0)