Skip to content

Commit 10f6e1a

Browse files
authored
Refactor matricize/unmatricize (#96)
1 parent 5c7a9bb commit 10f6e1a

File tree

8 files changed

+252
-74
lines changed

8 files changed

+252
-74
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
3-
version = "0.4.6"
43
authors = ["ITensor developers <[email protected]> and contributors"]
4+
version = "0.5.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -22,11 +22,11 @@ TensorAlgebraTensorOperationsExt = "TensorOperations"
2222
[compat]
2323
ArrayLayouts = "1.10.4"
2424
BlockArrays = "1.7.2"
25-
EllipsisNotation = "1.8.0"
25+
EllipsisNotation = "1.8"
2626
LinearAlgebra = "1.10"
2727
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
2828
TensorOperations = "5"
2929
TensorProducts = "0.1.5"
30-
TupleTools = "1.6.0"
30+
TupleTools = "1.6"
3131
TypeParameterAccessors = "0.2.1, 0.3, 0.4"
3232
julia = "1.10"

docs/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
44
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
55

6+
[sources]
7+
TensorAlgebra = {path = ".."}
8+
69
[compat]
710
Documenter = "1.8.1"
811
Literate = "2.20.1"
9-
TensorAlgebra = "0.4"
12+
TensorAlgebra = "0.5"

examples/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
[deps]
22
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33

4+
[sources]
5+
TensorAlgebra = {path = ".."}
6+
47
[compat]
5-
TensorAlgebra = "0.4"
8+
TensorAlgebra = "0.5"

src/blockedpermutation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function istrivialperm(t::Tuple)
88
return t == trivialperm(length(t))
99
end
1010

11-
value(::Val{N}) where {N} = N
11+
unval(::Val{N}) where {N} = N
1212

1313
_flatten_tuples(t::Tuple) = t
1414
function _flatten_tuples(t1::Tuple, t2::Tuple, trest::Tuple...)
@@ -87,7 +87,7 @@ function blockedpermvcat(
8787
end
8888

8989
function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...)
90-
value(len) != sum(length.(permblocks); init = 0) &&
90+
unval(len) != sum(length.(permblocks); init = 0) &&
9191
throw(ArgumentError("Invalid total length"))
9292
return permmortar(Tuple(permblocks))
9393
end
@@ -97,7 +97,7 @@ function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
9797
end
9898

9999
function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
100-
return value(vallength)
100+
return unval(vallength)
101101
end
102102

103103
# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,))
@@ -199,8 +199,11 @@ end
199199

200200
blockedperm(tp::BlockedTrivialPermutation) = tp
201201

202+
function blockedtrivialperm(blocklengths::Tuple{Vararg{Val}})
203+
return BlockedTrivialPermutation{length(blocklengths), unval.(blocklengths)}()
204+
end
202205
function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
203-
return BlockedTrivialPermutation{length(blocklengths), blocklengths}()
206+
return blockedtrivialperm(Val.(blocklengths))
204207
end
205208

206209
function trivialperm(blockedperm::AbstractBlockTuple)

src/contract/contract.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,29 @@ abstract type Algorithm end
44

55
Algorithm(alg::Algorithm) = alg
66

7-
struct Matricize <: Algorithm end
7+
struct Matricize{Style} <: Algorithm
8+
fusion_style::Style
9+
end
10+
Matricize() = Matricize(ReshapeFusion())
811

9-
default_contract_alg() = Matricize()
12+
function default_contract_alg(a1::AbstractArray, labels1, a2::AbstractArray, labels2)
13+
style1 = FusionStyle(a1)
14+
style2 = FusionStyle(a2)
15+
style1 == style2 || error("Styles must match.")
16+
return Matricize(style1)
17+
end
18+
function default_contractadd!_alg(
19+
a_dest::AbstractArray, labels_dest,
20+
a1::AbstractArray, labels1,
21+
a2::AbstractArray, labels2,
22+
α::Number, β::Number,
23+
)
24+
style_dest = FusionStyle(a_dest)
25+
style1 = FusionStyle(a1)
26+
style2 = FusionStyle(a2)
27+
style_dest == style1 == style2 || error("Styles must match.")
28+
return Matricize(style_dest)
29+
end
1030

1131
# Required interface if not using
1232
# matricized contraction.
@@ -29,7 +49,7 @@ function contract(
2949
labels1,
3050
a2::AbstractArray,
3151
labels2;
32-
alg = default_contract_alg(),
52+
alg = default_contract_alg(a1, labels1, a2, labels2),
3353
kwargs...,
3454
)
3555
return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...)
@@ -48,7 +68,7 @@ function contract(
4868
labels1,
4969
a2::AbstractArray,
5070
labels2;
51-
alg = default_contract_alg(),
71+
alg = default_contract_alg(a1, labels1, a2, labels2),
5272
kwargs...,
5373
)
5474
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...)
@@ -75,7 +95,7 @@ function contractadd!(
7595
labels2,
7696
α::Number,
7797
β::Number;
78-
alg = default_contract_alg(),
98+
alg = default_contractadd!_alg(a_dest, labels_dest, a1, labels1, a2, labels2, α, β),
7999
kwargs...,
80100
)
81101
contractadd!(
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LinearAlgebra: mul!
22

33
function contractadd!(
4-
::Matricize,
4+
alg::Matricize,
55
a_dest::AbstractArray,
66
biperm_dest::AbstractBlockPermutation{2},
77
a1::AbstractArray,
@@ -12,11 +12,10 @@ function contractadd!(
1212
β::Number,
1313
)
1414
invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1))
15-
1615
check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2)
17-
a1_mat = matricize(a1, biperm1)
18-
a2_mat = matricize(a2, biperm2)
16+
a1_mat = matricize(alg.fusion_style, a1, biperm1)
17+
a2_mat = matricize(alg.fusion_style, a2, biperm2)
1918
a_dest_mat = a1_mat * a2_mat
20-
unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
19+
unmatricizeadd!(alg.fusion_style, a_dest, a_dest_mat, invbiperm, α, β)
2120
return a_dest
2221
end

0 commit comments

Comments
 (0)