Skip to content

Commit f532818

Browse files
authored
Upgrade to TensorAlgebra v0.5 (#60)
1 parent 75fff94 commit f532818

File tree

4 files changed

+42
-31
lines changed

4 files changed

+42
-31
lines changed

.gitignore

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
*.jl.*.cov
2-
*.jl.cov
3-
*.jl.mem
1+
*.cov
2+
*.mem
43
*.o
54
*.swp
65
.DS_Store
76
.benchmarkci
87
.tmp
98
.vscode/
10-
Manifest.toml
9+
LocalPreferences.toml
10+
Manifest*.toml
1111
benchmark/*.json
1212
dev/
13-
docs/LocalPreferences.toml
14-
docs/Manifest.toml
1513
docs/build/
1614
docs/src/index.md
17-
examples/LocalPreferences.toml
18-
test/LocalPreferences.toml

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
3-
version = "0.3.2"
3+
version = "0.3.3"
44
authors = ["ITensor developers <[email protected]> and contributors"]
55

66
[deps]
@@ -36,7 +36,7 @@ GPUArraysCore = "0.2"
3636
LinearAlgebra = "1.10"
3737
MapBroadcast = "0.1.10"
3838
MatrixAlgebraKit = "0.6"
39-
TensorAlgebra = "0.3.10, 0.4"
39+
TensorAlgebra = "0.5"
4040
TensorProducts = "0.1.7"
4141
TypeParameterAccessors = "0.4.2"
4242
julia = "1.10"
Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module KroneckerArraysTensorAlgebraExt
22

33
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, , kroneckerfactors
4-
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, FusionStyle,
5-
matricize, unmatricize
4+
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation,
5+
FusionStyle, matricize, unmatricize
66

77
struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
88
a::A
@@ -11,32 +11,44 @@ end
1111
KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b)
1212
KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B)
1313

14-
TensorAlgebra.FusionStyle(a::AbstractKroneckerArray) = KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
14+
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
15+
return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
16+
end
1517
function matricize_kronecker(
16-
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
18+
style::FusionStyle, a::AbstractArray, length1::Val, length2::Val
1719
)
18-
return matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), biperm)
19-
matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), biperm)
20+
m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2)
21+
m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2)
22+
return m1 m2
2023
end
2124
function TensorAlgebra.matricize(
22-
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
25+
style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val
2326
)
24-
return matricize_kronecker(style, a, biperm)
27+
return matricize_kronecker(style, a, length1, length2)
2528
end
26-
# Fix ambiguity error.
27-
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
28-
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
29-
function TensorAlgebra.matricize(
30-
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
29+
function unmatricize_kronecker(
30+
style::FusionStyle,
31+
m::AbstractMatrix,
32+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
33+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
3134
)
32-
return matricize_kronecker(style, a, biperm)
35+
style1, style2 = kroneckerfactors(style)
36+
m1, m2 = kroneckerfactors(m)
37+
codomain1 = kroneckerfactors.(codomain_axes, 1)
38+
codomain2 = kroneckerfactors.(codomain_axes, 2)
39+
domain1 = kroneckerfactors.(domain_axes, 1)
40+
domain2 = kroneckerfactors.(domain_axes, 2)
41+
a1 = unmatricize(style1, m1, codomain1, domain1)
42+
a2 = unmatricize(style2, m2, codomain2, domain2)
43+
return a1 a2
3344
end
34-
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
35-
return unmatricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), kroneckerfactors.(ax, 1))
36-
unmatricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), kroneckerfactors.(ax, 2))
37-
end
38-
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
39-
return unmatricize_kronecker(style, a, ax)
45+
function TensorAlgebra.unmatricize(
46+
style::KroneckerFusion,
47+
m::AbstractMatrix,
48+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
49+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
50+
)
51+
return unmatricize_kronecker(style, m, codomain_axes, domain_axes)
4052
end
4153

4254
end

test/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1919
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2020
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
2121

22+
[sources]
23+
KroneckerArrays = {path = ".."}
24+
2225
[compat]
2326
Adapt = "4"
2427
Aqua = "0.8"
@@ -35,7 +38,7 @@ MatrixAlgebraKit = "0.6"
3538
SafeTestsets = "0.1"
3639
StableRNGs = "1.0"
3740
Suppressor = "0.2"
38-
TensorAlgebra = "0.3.10, 0.4"
41+
TensorAlgebra = "0.5"
3942
TensorProducts = "0.1.7"
4043
Test = "1.10"
4144
TestExtras = "0.3"

0 commit comments

Comments
 (0)