Skip to content

Commit 986e780

Browse files
authored
More functionality (#2)
1 parent 567ca4f commit 986e780

File tree

4 files changed

+358
-7
lines changed

4 files changed

+358
-7
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
7+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
811

912
[compat]
13+
DerivableInterfaces = "0.4.5"
14+
GPUArraysCore = "0.2.0"
1015
LinearAlgebra = "1.10"
16+
MatrixAlgebraKit = "0.2.0"
1117
julia = "1.10"

src/KroneckerArrays.jl

Lines changed: 233 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module KroneckerArrays
22

3+
using GPUArraysCore: GPUArraysCore
4+
35
export , ×
46

57
struct CartesianProduct{A,B}
@@ -28,6 +30,26 @@ end
2830
Base.first(r::CartesianProductUnitRange) = first(r.range)
2931
Base.last(r::CartesianProductUnitRange) = last(r.range)
3032

33+
function Base.axes(r::CartesianProductUnitRange)
34+
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
35+
end
36+
37+
using Base.Broadcast: DefaultArrayStyle
38+
for f in (:+, :-)
39+
@eval begin
40+
function Broadcast.broadcasted(
41+
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
42+
)
43+
return CartesianProductUnitRange(r.product, $f.(r.range, x))
44+
end
45+
function Broadcast.broadcasted(
46+
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
47+
)
48+
return CartesianProductUnitRange(r.product, $f.(x, r.range))
49+
end
50+
end
51+
end
52+
3153
struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
3254
a::A
3355
b::B
@@ -44,6 +66,15 @@ end
4466
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
4567
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
4668

69+
function Base.copy(a::KroneckerArray)
70+
return copy(a.a) copy(a.b)
71+
end
72+
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
73+
copyto!(dest.a, src.a)
74+
copyto!(dest.b, src.b)
75+
return dest
76+
end
77+
4778
function Base.similar(
4879
a::AbstractArray,
4980
elt::Type,
@@ -73,9 +104,21 @@ function Base.similar(
73104
return similar(arrayt, map(ax -> ax.product.a, axs))
74105
similar(arrayt, map(ax -> ax.product.b, axs))
75106
end
107+
function Base.similar(
108+
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
109+
axs::Tuple{
110+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
111+
},
112+
) where {A,B}
113+
return similar(A, map(ax -> ax.product.a, axs)) similar(B, map(ax -> ax.product.b, axs))
114+
end
76115

77116
Base.collect(a::KroneckerArray) = kron(a.a, a.b)
78117

118+
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
119+
return convert(Array{T,N}, collect(a))
120+
end
121+
79122
Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a))
80123

81124
function Base.axes(a::KroneckerArray)
@@ -107,12 +150,23 @@ end
107150
(a::Number, b::AbstractVecOrMat) = a * b
108151
(a::AbstractVecOrMat, b::Number) = a * b
109152

110-
function Base.getindex(::KroneckerArray, ::Int)
111-
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))
153+
function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
154+
GPUArraysCore.assertscalar("getindex")
155+
# Code logic from Kronecker.jl:
156+
# https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
157+
k, l = size(a.b)
158+
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
159+
end
160+
function Base.getindex(a::KroneckerMatrix, i::Integer)
161+
return a[CartesianIndices(a)[i]]
112162
end
113-
function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N}
114-
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))
163+
164+
function Base.getindex(a::KroneckerVector, i::Integer)
165+
GPUArraysCore.assertscalar("getindex")
166+
k = length(a.b)
167+
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
115168
end
169+
116170
function Base.getindex(a::KroneckerVector, i::CartesianProduct)
117171
return a.a[i.a] a.b[i.b]
118172
end
@@ -169,9 +223,18 @@ end
169223
function Base.:*(a::KroneckerArray, b::KroneckerArray)
170224
return (a.a * b.a) (a.b * b.b)
171225
end
172-
function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
226+
function LinearAlgebra.mul!(
227+
c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number
228+
)
229+
iszero(β) ||
230+
iszero(c) ||
231+
throw(
232+
ArgumentError(
233+
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
234+
),
235+
)
173236
mul!(c.a, a.a, b.a)
174-
mul!(c.b, a.b, b.b)
237+
mul!(c.b, a.b, b.b, α, β)
175238
return c
176239
end
177240
function LinearAlgebra.tr(a::KroneckerArray)
@@ -269,4 +332,168 @@ for op in (:+, :-)
269332
end
270333
end
271334

335+
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
336+
dest.a .= a.a
337+
dest.b .= a.b
338+
return dest
339+
end
340+
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
341+
if a.b == b.b
342+
map!(+, dest.a, a.a, b.a)
343+
dest.b .= a.b
344+
elseif a.a == b.a
345+
dest.a .= a.a
346+
map!(+, dest.b, a.b, b.b)
347+
else
348+
throw(
349+
ArgumentError(
350+
"KroneckerArray addition is only supported when the first or second arguments match.",
351+
),
352+
)
353+
end
354+
return dest
355+
end
356+
function Base.map!(
357+
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
358+
)
359+
dest.a .= f.x .* a.a
360+
dest.b .= a.b
361+
return dest
362+
end
363+
function Base.map!(
364+
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
365+
)
366+
dest.a .= a.a
367+
dest.b .= a.b .* f.x
368+
return dest
369+
end
370+
371+
using DerivableInterfaces: DerivableInterfaces, zero!
372+
function DerivableInterfaces.zero!(a::KroneckerArray)
373+
zero!(a.a)
374+
zero!(a.b)
375+
return a
376+
end
377+
378+
using MatrixAlgebraKit:
379+
MatrixAlgebraKit,
380+
AbstractAlgorithm,
381+
TruncationStrategy,
382+
default_eig_algorithm,
383+
default_eigh_algorithm,
384+
default_lq_algorithm,
385+
default_polar_algorithm,
386+
default_qr_algorithm,
387+
default_svd_algorithm,
388+
eig_full!,
389+
eig_trunc!,
390+
eig_vals!,
391+
eigh_full!,
392+
eigh_trunc!,
393+
eigh_vals!,
394+
initialize_output,
395+
left_null!,
396+
left_orth!,
397+
left_polar!,
398+
lq_compact!,
399+
lq_full!,
400+
qr_compact!,
401+
qr_full!,
402+
right_null!,
403+
right_orth!,
404+
right_polar!,
405+
svd_compact!,
406+
svd_full!,
407+
svd_trunc!,
408+
svd_vals!,
409+
truncate!
410+
411+
struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
412+
a::A
413+
b::B
414+
end
415+
416+
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
417+
ff = Symbol("default_", f, "_algorithm")
418+
@eval begin
419+
function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...)
420+
return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...))
421+
end
422+
end
423+
end
424+
425+
for f in (
426+
:eig_full!,
427+
:eigh_full!,
428+
:qr_compact!,
429+
:qr_full!,
430+
:left_polar!,
431+
:lq_compact!,
432+
:lq_full!,
433+
:right_polar!,
434+
:svd_compact!,
435+
:svd_full!,
436+
)
437+
@eval begin
438+
function MatrixAlgebraKit.initialize_output(
439+
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
440+
)
441+
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
442+
end
443+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
444+
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
445+
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
446+
return F
447+
end
448+
end
449+
end
450+
451+
for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
452+
@eval begin
453+
function MatrixAlgebraKit.initialize_output(
454+
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
455+
)
456+
return initialize_output($f, a.a, alg.a) initialize_output($f, a.b, alg.b)
457+
end
458+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
459+
$f(a.a, F.a, alg.a)
460+
$f(a.b, F.b, alg.b)
461+
return F
462+
end
463+
end
464+
end
465+
466+
for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
467+
@eval begin
468+
function MatrixAlgebraKit.truncate!(
469+
::typeof($f),
470+
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
471+
strategy::TruncationStrategy,
472+
)
473+
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
474+
end
475+
end
476+
end
477+
478+
for f in (:left_orth!, :right_orth!)
479+
@eval begin
480+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
481+
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
482+
end
483+
end
484+
end
485+
486+
for f in (:left_null!, :right_null!)
487+
@eval begin
488+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
489+
return initialize_output($f, a.a) initialize_output($f, a.b)
490+
end
491+
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
492+
$f(a.a, F.a; kwargs...)
493+
$f(a.b, F.b; kwargs...)
494+
return F
495+
end
496+
end
497+
end
498+
272499
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
44
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
67
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

0 commit comments

Comments
 (0)