Skip to content

Commit 3e02333

Browse files
authored
Overload TensorAlgebra.trivial_axis, better matricize overloads (#64)
1 parent 135f731 commit 3e02333

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -34,6 +34,6 @@ GPUArraysCore = "0.2"
3434
LinearAlgebra = "1.10"
3535
MapBroadcast = "0.1.10"
3636
MatrixAlgebraKit = "0.6"
37-
TensorAlgebra = "0.6.2"
37+
TensorAlgebra = "0.6.3"
3838
TypeParameterAccessors = "0.4.2"
3939
julia = "1.10"

ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module KroneckerArraysTensorAlgebraExt
33
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange,
44
, cartesianrange, kroneckerfactors, kroneckerfactortypes
55
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation,
6-
FusionStyle, matricize, tensor_product_axis, unmatricize
6+
FusionStyle, matricize, tensor_product_axis, trivial_axis, unmatricize
77

88
struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
99
a::A
@@ -19,14 +19,57 @@ function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange})
1919
return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...)
2020
end
2121

22+
function TensorAlgebra.trivial_axis(
23+
style::KroneckerFusion, side::Val{:codomain}, a::AbstractArray,
24+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
25+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
26+
)
27+
return trivial_kronecker(style, side, a, axes_codomain, axes_domain)
28+
end
29+
function TensorAlgebra.trivial_axis(
30+
style::KroneckerFusion, side::Val{:domain}, a::AbstractArray,
31+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
32+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
33+
)
34+
return trivial_kronecker(style, side, a, axes_codomain, axes_domain)
35+
end
36+
function trivial_kronecker(
37+
style::FusionStyle, side::Val, a::AbstractArray,
38+
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
39+
axes_domain::Tuple{Vararg{AbstractUnitRange}},
40+
)
41+
style_a, style_b = kroneckerfactors(style)
42+
a_a, a_b = kroneckerfactors(a)
43+
axes_codomain_a = kroneckerfactors.(axes_codomain, 1)
44+
axes_codomain_b = kroneckerfactors.(axes_codomain, 2)
45+
axes_domain_a = kroneckerfactors.(axes_domain, 1)
46+
axes_domain_b = kroneckerfactors.(axes_domain, 2)
47+
ra = trivial_axis(style_a, side, a_a, axes_codomain_a, axes_domain_a)
48+
rb = trivial_axis(style_b, side, a_b, axes_codomain_b, axes_domain_b)
49+
return cartesianrange(ra, rb)
50+
end
51+
2252
function TensorAlgebra.tensor_product_axis(
23-
style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange
53+
style::KroneckerFusion, side::Val{:codomain},
54+
r1::AbstractUnitRange, r2::AbstractUnitRange,
55+
)
56+
return tensor_product_kronecker(style, side, r1, r2)
57+
end
58+
function TensorAlgebra.tensor_product_axis(
59+
style::KroneckerFusion, side::Val{:domain},
60+
r1::AbstractUnitRange, r2::AbstractUnitRange,
61+
)
62+
return tensor_product_kronecker(style, side, r1, r2)
63+
end
64+
function tensor_product_kronecker(
65+
style::KroneckerFusion, side::Val,
66+
r1::AbstractUnitRange, r2::AbstractUnitRange,
2467
)
2568
style_a, style_b = kroneckerfactors(style)
2669
r1a, r1b = kroneckerfactors(r1)
2770
r2a, r2b = kroneckerfactors(r2)
28-
ra = tensor_product_axis(style_a, r1a, r2a)
29-
rb = tensor_product_axis(style_b, r1b, r2b)
71+
ra = tensor_product_axis(style_a, side, r1a, r2a)
72+
rb = tensor_product_axis(style_b, side, r1b, r2b)
3073
return cartesianrange(ra, rb)
3174
end
3275

@@ -44,8 +87,7 @@ function TensorAlgebra.matricize(
4487
end
4588

4689
function unmatricize_kronecker(
47-
style::FusionStyle,
48-
m::AbstractMatrix,
90+
style::FusionStyle, m::AbstractMatrix,
4991
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
5092
axes_domain::Tuple{Vararg{AbstractUnitRange}},
5193
)
@@ -60,8 +102,7 @@ function unmatricize_kronecker(
60102
return a1 a2
61103
end
62104
function TensorAlgebra.unmatricize(
63-
style::KroneckerFusion,
64-
m::AbstractMatrix,
105+
style::KroneckerFusion, m::AbstractMatrix,
65106
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
66107
domain_axes::Tuple{Vararg{AbstractUnitRange}},
67108
)

test/test_tensoralgebra.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
using TensorAlgebra: matricize, tensor_product_axis, unmatricize
1+
using TensorAlgebra: matricize, tensor_product_axis, trivial_axis, unmatricize
22
using KroneckerArrays: , cartesianrange, kroneckerfactors, unproduct
33
using Test: @test, @testset
44

55
@testset "TensorAlgebraExt" begin
66
@testset "tensor_product_axis" begin
7+
r = cartesianrange(2, 3)
8+
@test trivial_axis(r) cartesianrange(1, 1)
9+
710
r1 = cartesianrange(2, 3)
811
r2 = cartesianrange(4, 5)
912
r = tensor_product_axis(r1, r2)
@@ -15,7 +18,8 @@ using Test: @test, @testset
1518
@testset "matricize/unmatricize" begin
1619
a = randn(2, 2, 2) randn(3, 3, 3)
1720
m = matricize(a, (1, 2), (3,))
18-
@test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) matricize(kroneckerfactors(a, 2), (1, 2), (3,))
21+
@test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,))
22+
matricize(kroneckerfactors(a, 2), (1, 2), (3,))
1923
@test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a
2024
end
2125
end

0 commit comments

Comments
 (0)