Skip to content

Commit 69a8322

Browse files
authored
Generalize blockedperm ellipsis inputs, change constructor names (#27)
1 parent f67770f commit 69a8322

9 files changed

+139
-236
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.11"
4+
version = "0.2.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/blockedpermutation.jl

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,50 +32,60 @@ widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple
3232
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
3333
# TODO: Optimize with StaticNumbers.jl or generated functions, see:
3434
# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
35-
function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})
35+
function blockedperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})
3636
return blockedperm(BlockedTuple(perm, blocklengths))
3737
end
3838

39-
function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)
39+
function blockedperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)
4040
return blockedperm(BlockedTuple(perm, BlockLengths))
4141
end
4242

43-
function Base.invperm(blockedperm::AbstractBlockPermutation)
43+
function Base.invperm(bp::AbstractBlockPermutation)
4444
# use Val to preserve compile time info
45-
return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm)))
45+
return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp)))
4646
end
4747

4848
#
4949
# Constructors
5050
#
5151

52+
function blockedperm(bt::AbstractBlockTuple)
53+
return permmortar(blocks(bt))
54+
end
55+
5256
# Bipartition a vector according to the
5357
# bipartitioned permutation.
5458
# Like `Base.permute!` block out-of-place and blocked.
5559
function blockpermute(v, blockedperm::AbstractBlockPermutation)
5660
return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))
5761
end
5862

59-
# blockedperm((4, 3), (2, 1))
60-
function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing)
61-
return blockedperm(length, permblocks...)
63+
# blockedpermvcat((4, 3), (2, 1))
64+
function blockedpermvcat(
65+
permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing
66+
)
67+
return blockedpermvcat(length, permblocks...)
6268
end
6369

64-
function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...)
65-
return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)
70+
function blockedpermvcat(::Nothing, permblocks::Tuple{Vararg{Int}}...)
71+
return blockedpermvcat(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)
6672
end
6773

68-
# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
69-
function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...)
70-
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
74+
# blockedpermvcat((3, 2), 1) == blockedpermvcat((3, 2), (1,))
75+
function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...)
76+
return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...)
7177
end
7278

73-
function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...)
74-
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
79+
function blockedpermvcat(
80+
permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs...
81+
)
82+
return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...)
7583
end
7684

77-
function blockedperm(bt::AbstractBlockTuple)
78-
return blockedperm(Val(length(bt)), blocks(bt)...)
85+
function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...)
86+
value(len) != sum(length.(permblocks); init=0) &&
87+
throw(ArgumentError("Invalid total length"))
88+
return permmortar(Tuple(permblocks))
7989
end
8090

8191
function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
@@ -86,25 +96,39 @@ function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
8696
return value(vallength)
8797
end
8898

89-
# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1)
90-
# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1)
91-
function blockedperm(
92-
permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing
99+
# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,))
100+
# blockedpermvcat((4, 3), .., 1; length=Val(5)) == blockedpermvcat((4, 3), (2,), (5,), (1,))
101+
# blockedpermvcat((4, 3), (..,), 1) == blockedpermvcat((4, 3), (2,), (1,))
102+
# blockedpermvcat((4, 3), (..,), 1; length=Val(5)) == blockedpermvcat((4, 3), (2, 5), (1,))
103+
function blockedpermvcat(
104+
permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...;
105+
length::Union{Val,Nothing}=nothing,
93106
)
94107
# Check there is only one `Ellipsis`.
95-
@assert isone(count(x -> x isa Ellipsis, permblocks))
96-
specified_permblocks = filter(x -> !(x isa Ellipsis), permblocks)
97-
unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks)
108+
@assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks))
109+
specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks)
110+
unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)
98111
specified_perm = flatten_tuples(specified_permblocks)
99112
len = _blockedperm_length(length, specified_perm)
100-
unspecified_dims = Tuple(setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks)))
101-
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, unspecified_dims)
102-
return blockedperm(permblocks_specified...)
113+
unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm)
114+
ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible
115+
insert = unspecified_dims(
116+
permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified
117+
)
118+
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert)
119+
return blockedpermvcat(permblocks_specified...)
120+
end
121+
122+
function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val)
123+
return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),)
124+
end
125+
function unspecified_dims(::Ellipsis, unspecified_dims_vec, ndims_unspecified::Val)
126+
return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified)
103127
end
104128

105129
# Version of `indexin` that outputs a `blockedperm`.
106130
function blockedperm_indexin(collection, subs...)
107-
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
131+
return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
108132
end
109133

110134
#
@@ -138,7 +162,7 @@ function BlockArrays.blocklengths(
138162
return BlockLengths
139163
end
140164

141-
function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
165+
function permmortar(permblocks::Tuple{Vararg{Tuple{Vararg{Int}}}})
142166
blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}(
143167
flatten_tuples(permblocks)
144168
)

src/contract/blockedperms.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
2222
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)
2323

2424
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
25-
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
25+
biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...)
2626
permblocks1 = (perm_codomain1, perm_domain1)
27-
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
27+
biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...)
2828
permblocks2 = (perm_codomain2, perm_domain2)
29-
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
29+
biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...)
3030
return biperm_dest, biperm1, biperm2
3131
end

src/fusedims.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ end
4545
# Fix ambiguity issue
4646
fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a
4747

48-
# TODO: Is this needed? Maybe delete.
4948
function fusedims(a::AbstractArray, permblocks...)
50-
return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a))))
49+
return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a))))
5150
end
5251

5352
function fuseaxes(

test/Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
4-
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
54
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
65
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
76
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -10,7 +9,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
109
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1110
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1211
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
13-
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1412
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1513
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1614
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
@@ -21,10 +19,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
2119

2220
[compat]
2321
Aqua = "0.8.9"
24-
BlockSparseArrays = "0.2"
2522
Random = "1.10"
2623
SafeTestsets = "0.1"
27-
SparseArraysBase = "0.2.11"
2824
Suppressor = "0.2"
2925
SymmetrySectors = "0.1"
3026
TensorOperations = "5.1.3"

test/test_basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3232
a_fused = fusedims(a, (3, 1), .., 2)
3333
@test eltype(a_fused) === elt
3434
@test a_fused reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3))
35-
a_fused = fusedims(a, (3, 1), ..)
35+
a_fused = fusedims(a, (3, 1), (..,))
3636
@test eltype(a_fused) === elt
37-
@test a_fused reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5))
37+
@test a_fused reshape(permutedims(a, (3, 1, 2, 4)), (8, 15))
3838
end
3939
@testset "splitdims (eltype=$elt)" for elt in elts
4040
a = randn(elt, 6, 20)

test/test_blockarrays_contract.jl

Lines changed: 19 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize
2-
using BlockSparseArrays: BlockSparseArray
3-
using SparseArraysBase: densearray
42
using TensorAlgebra: contract
53
using Random: randn!
64
using Test: @test, @test_broken, @testset
75

86
function randn_blockdiagonal(elt::Type, axes::Tuple)
9-
a = BlockSparseArray{elt}(axes)
7+
a = zeros(elt, axes)
108
blockdiaglength = minimum(blocksize(a))
119
for i in 1:blockdiaglength
1210
b = Block(ntuple(Returns(i), ndims(a)))
@@ -18,74 +16,14 @@ end
1816
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1917
@testset "`contract` blocked arrays (eltype=$elt)" for elt in elts
2018
d = blockedrange([2, 3])
21-
a1_sba = randn_blockdiagonal(elt, (d, d, d, d))
22-
a2_sba = randn_blockdiagonal(elt, (d, d, d, d))
23-
a3_sba = randn_blockdiagonal(elt, (d, d))
24-
a1_dense = densearray(a1_sba)
25-
a2_dense = densearray(a2_sba)
26-
a3_dense = densearray(a3_sba)
27-
28-
@testset "BlockArray" begin
29-
a1 = BlockArray(a1_sba)
30-
a2 = BlockArray(a2_sba)
31-
a3 = BlockArray(a3_sba)
32-
33-
# matrix matrix
34-
@test_broken a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
35-
#=
36-
a_dest_dense, dimnames_dest_dense = contract(
37-
a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4)
38-
)
39-
@test dimnames_dest == dimnames_dest_dense
40-
@test size(a_dest) == size(a_dest_dense)
41-
@test a_dest isa BlockArray
42-
@test a_dest ≈ a_dest_dense
43-
=#
44-
45-
# matrix vector
46-
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
47-
#=
48-
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
49-
@test dimnames_dest == dimnames_dest_dense
50-
@test size(a_dest) == size(a_dest_dense)
51-
@test a_dest isa BlockArray
52-
@test a_dest ≈ a_dest_dense
53-
=#
54-
55-
# vector matrix
56-
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
57-
#=
58-
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
59-
@test dimnames_dest == dimnames_dest_dense
60-
@test size(a_dest) == size(a_dest_dense)
61-
@test a_dest isa BlockArray
62-
@test a_dest ≈ a_dest_dense
63-
=#
64-
65-
# vector vector
66-
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
67-
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
68-
@test dimnames_dest == dimnames_dest_dense
69-
@test size(a_dest) == size(a_dest_dense)
70-
@test_broken a_dest isa BlockArray # TBD relax to AbstractArray{elt,0}?
71-
@test a_dest a_dest_dense
72-
73-
# outer product
74-
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
75-
#=
76-
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
77-
@test dimnames_dest == dimnames_dest_dense
78-
@test size(a_dest) == size(a_dest_dense)
79-
@test a_dest isa BlockArray
80-
@test a_dest ≈ a_dest_dense
81-
=#
82-
end
19+
a1 = randn_blockdiagonal(elt, (d, d, d, d))
20+
a2 = randn_blockdiagonal(elt, (d, d, d, d))
21+
a3 = randn_blockdiagonal(elt, (d, d))
22+
a1_dense = convert(Array, a1)
23+
a2_dense = convert(Array, a2)
24+
a3_dense = convert(Array, a3)
8325

8426
@testset "BlockedArray" begin
85-
a1 = BlockedArray(a1_sba)
86-
a2 = BlockedArray(a2_sba)
87-
a3 = BlockedArray(a3_sba)
88-
8927
# matrix matrix
9028
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
9129
a_dest_dense, dimnames_dest_dense = contract(
@@ -97,31 +35,27 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9735
@test a_dest a_dest_dense
9836

9937
# matrix vector
100-
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
101-
#=
38+
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
10239
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
10340
@test dimnames_dest == dimnames_dest_dense
10441
@test size(a_dest) == size(a_dest_dense)
10542
@test a_dest isa BlockedArray
10643
@test a_dest a_dest_dense
107-
=#
10844

10945
# vector matrix
110-
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
111-
#=
46+
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
11247
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
11348
@test dimnames_dest == dimnames_dest_dense
11449
@test size(a_dest) == size(a_dest_dense)
11550
@test a_dest isa BlockedArray
11651
@test a_dest a_dest_dense
117-
=#
11852

11953
# vector vector
12054
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
12155
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
12256
@test dimnames_dest == dimnames_dest_dense
12357
@test size(a_dest) == size(a_dest_dense)
124-
@test_broken a_dest isa BlockedArray # TBD relax to AbstractArray{elt,0}?
58+
@test_broken a_dest isa BlockedArray{elt,0}
12559
@test a_dest a_dest_dense
12660

12761
# outer product
@@ -133,8 +67,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
13367
@test a_dest a_dest_dense
13468
end
13569

136-
@testset "BlockSparseArray" begin
137-
a1, a2, a3 = a1_sba, a2_sba, a3_sba
70+
@testset "BlockArray" begin
71+
a1, a3, a3 = BlockArray.((a1, a2, a3))
13872

13973
# matrix matrix
14074
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
@@ -143,41 +77,39 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
14377
)
14478
@test dimnames_dest == dimnames_dest_dense
14579
@test size(a_dest) == size(a_dest_dense)
146-
@test a_dest isa BlockSparseArray
80+
@test a_dest isa BlockArray
14781
@test a_dest a_dest_dense
14882

14983
# matrix vector
150-
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
151-
#=
84+
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
15285
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
15386
@test dimnames_dest == dimnames_dest_dense
15487
@test size(a_dest) == size(a_dest_dense)
155-
@test a_dest isa BlockSparseArray
88+
@test a_dest isa BlockArray
15689
@test a_dest a_dest_dense
157-
=#
15890

15991
# vector matrix
16092
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
16193
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
16294
@test dimnames_dest == dimnames_dest_dense
16395
@test size(a_dest) == size(a_dest_dense)
164-
@test a_dest isa BlockSparseArray
96+
@test a_dest isa BlockArray
16597
@test a_dest a_dest_dense
16698

16799
# vector vector
168100
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
169101
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
170102
@test dimnames_dest == dimnames_dest_dense
171103
@test size(a_dest) == size(a_dest_dense)
172-
@test a_dest isa BlockSparseArray
104+
@test_broken a_dest isa BlockArray{elt,0}
173105
@test a_dest a_dest_dense
174106

175107
# outer product
176-
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
177108
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
109+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
178110
@test dimnames_dest == dimnames_dest_dense
179111
@test size(a_dest) == size(a_dest_dense)
180-
@test a_dest isa BlockSparseArray
112+
@test a_dest isa BlockArray
181113
@test a_dest a_dest_dense
182114
end
183115
end

0 commit comments

Comments
 (0)