Skip to content

Commit 0f2daf7

Browse files
authored
Support for delta (#33)
1 parent 2c2d41e commit 0f2daf7

File tree

11 files changed

+392
-26
lines changed

11 files changed

+392
-26
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.27"
4+
version = "0.1.28"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -16,22 +16,25 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1616
[weakdeps]
1717
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1818
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
19+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1920
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2021

2122
[extensions]
2223
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
24+
KroneckerArraysTensorAlgebraExt = "TensorAlgebra"
2325
KroneckerArraysTensorProductsExt = "TensorProducts"
2426

2527
[compat]
2628
Adapt = "4.3"
2729
BlockArrays = "1.6"
2830
BlockSparseArrays = "0.9"
29-
DerivableInterfaces = "0.5"
30-
DiagonalArrays = "0.3.5"
31+
DerivableInterfaces = "0.5.3"
32+
DiagonalArrays = "0.3.11"
3133
FillArrays = "1.13"
3234
GPUArraysCore = "0.2"
3335
LinearAlgebra = "1.10"
3436
MapBroadcast = "0.1.9"
3537
MatrixAlgebraKit = "0.2"
38+
TensorAlgebra = "0.3.10"
3639
TensorProducts = "0.1.7"
3740
julia = "1.10"

ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ using KroneckerArrays:
3939
_similar
4040

4141
function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
42-
return mortar_axis(arg2.(eachblockaxis(r)))
42+
return mortar_axis(arg1.(eachblockaxis(r)))
4343
end
4444
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
4545
return mortar_axis(arg2.(eachblockaxis(r)))
@@ -56,15 +56,14 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe
5656
return block_axes(ax, Tuple(I)...)
5757
end
5858

59+
## TODO: Is this needed?
5960
function Base.getindex(
6061
a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2}
6162
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
62-
ax_a1 = arg1.(a.parentaxes)
63+
ax_a1 = map(arg1, a.parentaxes)
6364
a1 = ZeroBlocks{2,A}(ax_a1)[I...]
64-
65-
ax_a2 = arg2.(a.parentaxes)
65+
ax_a2 = map(arg2, a.parentaxes)
6666
a2 = ZeroBlocks{2,B}(ax_a2)[I...]
67-
6867
return a1 a2
6968
end
7069
function Base.getindex(
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
module KroneckerArraysTensorAlgebraExt
2+
3+
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2
4+
using TensorAlgebra:
5+
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize
6+
7+
struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle
8+
a::A
9+
b::B
10+
end
11+
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
12+
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
13+
function TensorAlgebra.FusionStyle(a::KroneckerArray)
14+
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
15+
end
16+
function matricize_kronecker(
17+
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
18+
)
19+
return matricize(arg1(style), arg1(a), biperm) matricize(arg2(style), arg2(a), biperm)
20+
end
21+
function TensorAlgebra.matricize(
22+
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
23+
)
24+
return matricize_kronecker(style, a, biperm)
25+
end
26+
# Fix ambiguity error.
27+
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
28+
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
29+
function TensorAlgebra.matricize(
30+
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
31+
)
32+
return matricize_kronecker(style, a, biperm)
33+
end
34+
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
35+
return unmatricize(arg1(style), arg1(a), arg1.(ax))
36+
unmatricize(arg2(style), arg2(a), arg2.(ax))
37+
end
38+
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
39+
return unmatricize_kronecker(style, a, ax)
40+
end
41+
42+
end

src/cartesianproduct.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
9898
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
9999
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))
100100

101+
function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange)
102+
prod = cartesianproduct(a)[cartesianproduct(i)]
103+
range = unproduct(a)[unproduct(i)]
104+
return cartesianrange(prod, range)
105+
end
106+
101107
function Base.show(io::IO, a::CartesianProductUnitRange)
102108
show(io, unproduct(a))
103109
return nothing

src/fillarrays/kroneckerarray.jl

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FillArrays: FillArrays, Zeros
1+
using FillArrays: FillArrays, Ones, Zeros
22
function FillArrays.fillsimilar(
33
a::Zeros{T},
44
ax::Tuple{
@@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
2121
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2222
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
2323

24+
using DiagonalArrays: Delta
25+
const DeltaKronecker{T,N,A<:Delta{T,N},B<:AbstractArray{T,N}} = KroneckerArray{T,N,A,B}
26+
const KroneckerDelta{T,N,A<:AbstractArray{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}
27+
const DeltaDelta{T,N,A<:Delta{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}
28+
2429
_getindex(a::Eye, I1::Colon, I2::Colon) = a
2530
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
2631
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
@@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
3035
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
3136
_view(a::Eye, I1::Colon, I2::Base.Slice) = a
3237

38+
function _getindex(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
39+
return a
40+
end
41+
function _view(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
42+
return a
43+
end
44+
3345
# Like `adapt` but preserves `Eye`.
3446
_adapt(to, a::Eye) = a
47+
_adapt(to, a::Delta) = a
3548

3649
# Allows customizing for `FillArrays.Eye`.
3750
function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T}
38-
_convert(AbstractMatrix{T}, a)
51+
return _convert(AbstractMatrix{T}, a)
3952
end
4053
function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
41-
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
54+
return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
4255
end
4356

4457
# Like `similar` but preserves `Eye`, `Ones`, etc.
@@ -61,8 +74,33 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange})
6174
return Eye{eltype(arrayt)}((only(unique(axs)),))
6275
end
6376

64-
# Like `copy` but preserves `Eye`.
77+
function _similar(a::Delta, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
78+
return Delta{elt}(axs)
79+
end
80+
function _similar(arrayt::Type{<:Delta}, axs::Tuple{Vararg{AbstractUnitRange}})
81+
return Delta{eltype(arrayt)}(axs)
82+
end
83+
84+
# Like `copy` but preserves `Eye`/`Delta`.
6585
_copy(a::Eye) = a
86+
_copy(a::Delta) = a
87+
88+
function _copyto!!(dest::Eye{<:Any,N}, src::Eye{<:Any,N}) where {N}
89+
size(dest) == size(src) ||
90+
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
91+
return dest
92+
end
93+
function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N}
94+
size(dest) == size(src) ||
95+
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
96+
return dest
97+
end
98+
99+
function _permutedims!!(dest::Delta, src::Delta, perm)
100+
Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) ||
101+
throw(ArgumentError("Permuted axes do not match."))
102+
return dest
103+
end
66104

67105
using Base.Broadcast:
68106
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
@@ -75,10 +113,16 @@ end
75113
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()
76114
Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2
77115

116+
function _copyto!!(dest::Eye, src::Broadcasted{<:EyeStyle,<:Any,typeof(identity)})
117+
axes(dest) == axes(src) || error("Dimension mismatch.")
118+
return dest
119+
end
120+
78121
function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
79122
return Eye{elt}(axes(bc))
80123
end
81124

125+
# TODO: Define in terms of `_copyto!!` that is called on each argument.
82126
function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}})
83127
dest2 = arg2(dest)
84128
f = LinearCombination(a)
@@ -99,6 +143,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye
99143
return error("Can't write in-place to `Eye ⊗ Eye`.")
100144
end
101145

146+
struct DeltaStyle{N} <: AbstractArrayStyle{N} end
147+
DeltaStyle(::Val{N}) where {N} = DeltaStyle{N}()
148+
DeltaStyle{M}(::Val{N}) where {M,N} = DeltaStyle{N}()
149+
function _BroadcastStyle(A::Type{<:Delta})
150+
return DeltaStyle{ndims(A)}()
151+
end
152+
Base.BroadcastStyle(style1::DeltaStyle, style2::DeltaStyle) = DeltaStyle()
153+
Base.BroadcastStyle(style1::DeltaStyle, style2::DefaultArrayStyle) = style2
154+
155+
function _copyto!!(dest::Delta, src::Broadcasted{<:DeltaStyle,<:Any,typeof(identity)})
156+
axes(dest) == axes(src) || error("Dimension mismatch.")
157+
return dest
158+
end
159+
160+
function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type)
161+
return Delta{elt}(axes(bc))
162+
end
163+
164+
# TODO: Dispatch on `DeltaStyle`.
165+
function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle})
166+
dest2 = arg2(dest)
167+
f = LinearCombination(a)
168+
args = arguments(a)
169+
arg2s = arg2.(args)
170+
dest2 .= f.(arg2s...)
171+
return dest
172+
end
173+
# TODO: Dispatch on `DeltaStyle`.
174+
function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle})
175+
dest1 = arg1(dest)
176+
f = LinearCombination(a)
177+
args = arguments(a)
178+
arg1s = arg1.(args)
179+
dest1 .= f.(arg1s...)
180+
return dest
181+
end
182+
# TODO: Dispatch on `DeltaStyle`.
183+
function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle})
184+
return error("Can't write in-place to `Delta ⊗ Delta`.")
185+
end
186+
102187
# Simplification rules similar to those for FillArrays.jl:
103188
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
104189
using FillArrays: Zeros

src/kroneckerarray.jl

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,26 @@ _copy(a::AbstractArray) = copy(a)
4343
function Base.copy(a::KroneckerArray)
4444
return _copy(arg1(a)) _copy(arg2(a))
4545
end
46-
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
47-
copyto!(arg1(dest), arg1(src))
48-
copyto!(arg2(dest), arg2(src))
46+
47+
# Allows extra customization, like for `FillArrays.Eye`.
48+
function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N}
49+
copyto!(dest, src)
50+
return dest
51+
end
52+
function _copyto!!(dest::AbstractArray, src::Broadcasted)
53+
copyto!(dest, src)
54+
return dest
55+
end
56+
57+
function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N}
58+
return copyto!_kronecker(dest, src)
59+
end
60+
function copyto!_kronecker(
61+
dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}
62+
) where {N}
63+
# TODO: Check if neither argument is mutated and if so error.
64+
_copyto!!(arg1(dest), arg1(src))
65+
_copyto!!(arg2(dest), arg2(src))
4966
return dest
5067
end
5168

@@ -110,6 +127,23 @@ function Base.similar(
110127
return similar(promote_type(A, B), sz)
111128
end
112129

130+
function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm)
131+
permutedims!(dest, src, perm)
132+
return dest
133+
end
134+
135+
using DerivableInterfaces: DerivableInterfaces, permuteddims
136+
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
137+
return permuteddims(arg1(a), perm) permuteddims(arg2(a), perm)
138+
end
139+
140+
function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
141+
# TODO: Error if neither argument is mutable.
142+
_permutedims!!(arg1(dest), arg1(src), perm)
143+
_permutedims!!(arg2(dest), arg2(src), perm)
144+
return dest
145+
end
146+
113147
function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
114148
return (t[1]..., flatten(Base.tail(t))...)
115149
end
@@ -128,7 +162,7 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
128162
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
129163
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
130164
c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N)))
131-
sz = ntuple(i -> size(a, i) * size(b, i), N)
165+
sz = reverse(ntuple(i -> size(a, i) * size(b, i), N))
132166
return permutedims(reshape(c′, sz), reverse(ntuple(identity, N)))
133167
end
134168
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
@@ -284,6 +318,12 @@ for f in [:transpose, :adjoint, :inv]
284318
end
285319
end
286320

321+
function Base.reshape(
322+
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
323+
)
324+
return reshape(arg1(a), map(arg1, ax)) reshape(arg2(a), map(arg2, ax))
325+
end
326+
287327
# Allows for customizations for FillArrays.
288328
_BroadcastStyle(x) = BroadcastStyle(x)
289329

@@ -405,8 +445,8 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
405445
Broadcast.broadcastable(a::KroneckerBroadcasted) = a
406446
Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) copy(arg2(a))
407447
function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted)
408-
copyto!(arg1(dest), copy(arg1(a)))
409-
copyto!(arg2(dest), copy(arg2(a)))
448+
_copyto!!(arg1(dest), arg1(a))
449+
_copyto!!(arg2(dest), arg2(a))
410450
return dest
411451
end
412452
function Base.eltype(a::KroneckerBroadcasted)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1414
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1515
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1616
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
17+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1718
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1920
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
@@ -34,6 +35,7 @@ MatrixAlgebraKit = "0.2"
3435
SafeTestsets = "0.1"
3536
StableRNGs = "1.0"
3637
Suppressor = "0.2"
38+
TensorAlgebra = "0.3.10"
3739
TensorProducts = "0.1.7"
3840
Test = "1.10"
3941
TestExtras = "0.3"

0 commit comments

Comments
 (0)