Skip to content

Commit 135f731

Browse files
authored
Upgrade to TensorAlgebra v0.6 (#63)
1 parent f532818 commit 135f731

File tree

6 files changed

+52
-56
lines changed

6 files changed

+52
-56
lines changed

Project.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
3-
version = "0.3.3"
43
authors = ["ITensor developers <[email protected]> and contributors"]
4+
version = "0.3.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -18,12 +18,10 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1818
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1919
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
2020
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
21-
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2221

2322
[extensions]
2423
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2524
KroneckerArraysTensorAlgebraExt = "TensorAlgebra"
26-
KroneckerArraysTensorProductsExt = "TensorProducts"
2725

2826
[compat]
2927
Adapt = "4.3"
@@ -36,7 +34,6 @@ GPUArraysCore = "0.2"
3634
LinearAlgebra = "1.10"
3735
MapBroadcast = "0.1.10"
3836
MatrixAlgebraKit = "0.6"
39-
TensorAlgebra = "0.5"
40-
TensorProducts = "0.1.7"
37+
TensorAlgebra = "0.6.2"
4138
TypeParameterAccessors = "0.4.2"
4239
julia = "1.10"

ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module KroneckerArraysTensorAlgebraExt
22

3-
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, , kroneckerfactors
3+
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange,
4+
, cartesianrange, kroneckerfactors, kroneckerfactortypes
45
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation,
5-
FusionStyle, matricize, unmatricize
6+
FusionStyle, matricize, tensor_product_axis, unmatricize
67

78
struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
89
a::A
@@ -11,33 +12,49 @@ end
1112
KroneckerArrays.kroneckerfactors(style::KroneckerFusion) = (style.a, style.b)
1213
KroneckerArrays.kroneckerfactortypes(::Type{KroneckerFusion{A, B}}) where {A, B} = (A, B)
1314

14-
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
15-
return KroneckerFusion(FusionStyle.(kroneckerfactors(a))...)
15+
function TensorAlgebra.FusionStyle(A::Type{<:AbstractKroneckerArray})
16+
return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...)
1617
end
18+
function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange})
19+
return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...)
20+
end
21+
22+
function TensorAlgebra.tensor_product_axis(
23+
style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange
24+
)
25+
style_a, style_b = kroneckerfactors(style)
26+
r1a, r1b = kroneckerfactors(r1)
27+
r2a, r2b = kroneckerfactors(r2)
28+
ra = tensor_product_axis(style_a, r1a, r2a)
29+
rb = tensor_product_axis(style_b, r1b, r2b)
30+
return cartesianrange(ra, rb)
31+
end
32+
1733
function matricize_kronecker(
18-
style::FusionStyle, a::AbstractArray, length1::Val, length2::Val
34+
style::FusionStyle, a::AbstractArray, length_codomain::Val
1935
)
20-
m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length1, length2)
21-
m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length1, length2)
36+
m1 = matricize(kroneckerfactors(style, 1), kroneckerfactors(a, 1), length_codomain)
37+
m2 = matricize(kroneckerfactors(style, 2), kroneckerfactors(a, 2), length_codomain)
2238
return m1 m2
2339
end
2440
function TensorAlgebra.matricize(
25-
style::KroneckerFusion, a::AbstractArray, length1::Val, length2::Val
41+
style::KroneckerFusion, a::AbstractArray, length_codomain::Val
2642
)
27-
return matricize_kronecker(style, a, length1, length2)
43+
return matricize_kronecker(style, a, length_codomain)
2844
end
45+
2946
function unmatricize_kronecker(
3047
style::FusionStyle,
3148
m::AbstractMatrix,
32-
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
33-
domain_axes::Tuple{Vararg{AbstractUnitRange}},
49+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
50+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
3451
)
3552
style1, style2 = kroneckerfactors(style)
3653
m1, m2 = kroneckerfactors(m)
37-
codomain1 = kroneckerfactors.(codomain_axes, 1)
38-
codomain2 = kroneckerfactors.(codomain_axes, 2)
39-
domain1 = kroneckerfactors.(domain_axes, 1)
40-
domain2 = kroneckerfactors.(domain_axes, 2)
54+
codomain1 = kroneckerfactors.(axes_codomain, 1)
55+
codomain2 = kroneckerfactors.(axes_codomain, 2)
56+
domain1 = kroneckerfactors.(axes_domain, 1)
57+
domain2 = kroneckerfactors.(axes_domain, 2)
4158
a1 = unmatricize(style1, m1, codomain1, domain1)
4259
a2 = unmatricize(style2, m2, codomain2, domain2)
4360
return a1 a2

ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1515
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1616
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1717
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
18-
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1918
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2019
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
2120

@@ -38,7 +37,6 @@ MatrixAlgebraKit = "0.6"
3837
SafeTestsets = "0.1"
3938
StableRNGs = "1.0"
4039
Suppressor = "0.2"
41-
TensorAlgebra = "0.5"
42-
TensorProducts = "0.1.7"
40+
TensorAlgebra = "0.6.2"
4341
Test = "1.10"
4442
TestExtras = "0.3"

test/test_tensoralgebra.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1-
using TensorAlgebra: matricize, unmatricize
2-
using KroneckerArrays: , kroneckerfactors
1+
using TensorAlgebra: matricize, tensor_product_axis, unmatricize
2+
using KroneckerArrays: , cartesianrange, kroneckerfactors, unproduct
33
using Test: @test, @testset
44

55
@testset "TensorAlgebraExt" begin
6-
a = randn(2, 2, 2) randn(3, 3, 3)
7-
m = matricize(a, (1, 2), (3,))
8-
@test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) matricize(kroneckerfactors(a, 2), (1, 2), (3,))
9-
@test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a
6+
@testset "tensor_product_axis" begin
7+
r1 = cartesianrange(2, 3)
8+
r2 = cartesianrange(4, 5)
9+
r = tensor_product_axis(r1, r2)
10+
@test r cartesianrange(8, 15)
11+
@test kroneckerfactors(r, 1) Base.OneTo(8)
12+
@test kroneckerfactors(r, 2) Base.OneTo(15)
13+
@test unproduct(r) Base.OneTo(120)
14+
end
15+
@testset "matricize/unmatricize" begin
16+
a = randn(2, 2, 2) randn(3, 3, 3)
17+
m = matricize(a, (1, 2), (3,))
18+
@test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) matricize(kroneckerfactors(a, 2), (1, 2), (3,))
19+
@test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a
20+
end
1021
end

test/test_tensorproducts.jl

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)