Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
*.cov
*.mem
*.o
*.swp
.DS_Store
.benchmarkci
.tmp
.vscode/
Manifest.toml
LocalPreferences.toml
Manifest*.toml
benchmark/*.json
dev/
docs/LocalPreferences.toml
docs/Manifest.toml
docs/build/
docs/src/index.md
examples/LocalPreferences.toml
test/LocalPreferences.toml
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
version = "0.3.2"
version = "0.3.3"
authors = ["ITensor developers <[email protected]> and contributors"]

[deps]
Expand Down Expand Up @@ -36,7 +36,7 @@ GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.10"
MatrixAlgebraKit = "0.6"
TensorAlgebra = "0.3.10, 0.4"
TensorAlgebra = "0.5"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4.2"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module KroneckerArraysTensorAlgebraExt

using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, ⊗, kroneckerfactors
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle,
matricize, unmatricize
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation,
FusionStyle, matricize, unmatricize

struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
a::A
Expand All @@ -11,32 +11,44 @@ end
KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b)
KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B)

TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
end
function matricize_kronecker(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
style::FusionStyle, a::AbstractArray, length1::Val, length2::Val
)
return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm) ⊗
matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm)
m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2)
m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2)
return m1 ⊗ m2
end
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val
)
return matricize_kronecker(style, a, biperm)
return matricize_kronecker(style, a, length1, length2)
end
# Fix ambiguity error.
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
function unmatricize_kronecker(
style::FusionStyle,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
return matricize_kronecker(style, a, biperm)
style1, style2 = kroneckerfactors(style)
m1, m2 = kroneckerfactors(m)
codomain1 = kroneckerfactors.(codomain_axes, 1)
codomain2 = kroneckerfactors.(codomain_axes, 2)
domain1 = kroneckerfactors.(domain_axes, 1)
domain2 = kroneckerfactors.(domain_axes, 2)
a1 = unmatricize(style1, m1, codomain1, domain1)
a2 = unmatricize(style2, m2, codomain2, domain2)
return a1 ⊗ a2
end
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1)) ⊗
unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2))
end
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize_kronecker(style, a, ax)
function TensorAlgebra.unmatricize(
style::KroneckerFusion,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
return unmatricize_kronecker(style, m, codomain_axes, domain_axes)
end

end
5 changes: 4 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[sources]
KroneckerArrays = {path = ".."}

[compat]
Adapt = "4"
Aqua = "0.8"
Expand All @@ -35,7 +38,7 @@ MatrixAlgebraKit = "0.6"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
TensorAlgebra = "0.3.10, 0.4"
TensorAlgebra = "0.5"
TensorProducts = "0.1.7"
Test = "1.10"
TestExtras = "0.3"
Loading