Skip to content

Generalize blockedperm ellipsis inputs, change constructor names #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.10"
version = "0.1.11"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
42 changes: 30 additions & 12 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,17 @@
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
end

function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...)
function blockedperm(

Check warning on line 73 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L73

Added line #L73 was not covered by tests
permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs...
)
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
end

function blockedperm(bt::AbstractBlockTuple)
return blockedperm(Val(length(bt)), blocks(bt)...)
# keep len kwarg to be consistent with other method signatures
function blockedperm(bt::AbstractBlockTuple; length::Union{Val,Nothing}=nothing)
!(length ∈ (nothing, Val(Base.length(bt)))) &&

Check warning on line 81 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L80-L81

Added lines #L80 - L81 were not covered by tests
throw(ArgumentError("Invalid total length"))
return blockedperm(Val(Base.length(bt)), blocks(bt)...)

Check warning on line 83 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L83

Added line #L83 was not covered by tests
end

function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
Expand All @@ -86,22 +91,34 @@
return value(vallength)
end

# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1)
# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1)
# blockedperm((4, 3), .., 1) == blockedperm((4, 3), (2,), (1,))
# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), (2,), (5,), (1,))
# blockedperm((4, 3), (..,), 1) == blockedperm((4, 3), (2,), (1,))
# blockedperm((4, 3), (..,), 1; length=Val(5)) == blockedperm((4, 3), (2, 5), (1,))
function blockedperm(
permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing
permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...;
length::Union{Val,Nothing}=nothing,
)
# Check there is only one `Ellipsis`.
@assert isone(count(x -> x isa Ellipsis, permblocks))
specified_permblocks = filter(x -> !(x isa Ellipsis), permblocks)
unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks)
@assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks))
specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks)
unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)

Check warning on line 105 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L103-L105

Added lines #L103 - L105 were not covered by tests
specified_perm = flatten_tuples(specified_permblocks)
len = _blockedperm_length(length, specified_perm)
unspecified_dims = Tuple(setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks)))
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, unspecified_dims)
unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm)
UD = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible
insert = unspecified_dims(typeof(permblocks[unspecified_dim]), unspecified_dims_vec, UD)
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert)

Check warning on line 111 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L108-L111

Added lines #L108 - L111 were not covered by tests
return blockedperm(permblocks_specified...)
end

function unspecified_dims(::Type{Tuple{Ellipsis}}, unspecified_dims_vec, UD::Val)
return (NTuple{value(UD),Int}(unspecified_dims_vec),)

Check warning on line 116 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L115-L116

Added lines #L115 - L116 were not covered by tests
end
function unspecified_dims(::Type{Ellipsis}, unspecified_dims_vec, UD::Val)
return NTuple{value(UD),Tuple{Int}}(Tuple.(unspecified_dims_vec))

Check warning on line 119 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
end

# Version of `indexin` that outputs a `blockedperm`.
function blockedperm_indexin(collection, subs...)
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
Expand Down Expand Up @@ -138,10 +155,11 @@
return BlockLengths
end

function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
function blockedperm(len::Val, permblocks::Tuple{Vararg{Int}}...)

Check warning on line 158 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L158

Added line #L158 was not covered by tests
blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}(
flatten_tuples(permblocks)
)
value(len) != length(blockedperm) && throw(ArgumentError("Invalid total length"))

Check warning on line 162 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L162

Added line #L162 was not covered by tests
@assert isperm(blockedperm)
return blockedperm
end
Expand Down
4 changes: 2 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a_fused = fusedims(a, (3, 1), .., 2)
@test eltype(a_fused) === elt
@test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3))
a_fused = fusedims(a, (3, 1), ..)
a_fused = fusedims(a, (3, 1), (..,))
@test eltype(a_fused) === elt
@test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5))
@test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15))
end
@testset "splitdims (eltype=$elt)" for elt in elts
a = randn(elt, 6, 20)
Expand Down
55 changes: 40 additions & 15 deletions test/test_blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ using TensorAlgebra:
blockedperm,
blockedperm_indexin,
blockedtrivialperm,
trivialperm
trivialperm,
tuplemortar

@testset "BlockedPermutation" begin
p = @constinferred blockedperm((3, 4, 5), (2, 1))
Expand Down Expand Up @@ -63,13 +64,18 @@ using TensorAlgebra:
@test p isa BlockedPermutation{0}

p = blockedperm((3, 2), (), (1,))
bt = BlockedTuple{3,(2, 0, 1)}((3, 2, 1))
bt = tuplemortar(((3, 2), (), (1,)))
@test (@constinferred BlockedTuple(p)) == bt
@test (@constinferred map(identity, p)) == bt
@test (@constinferred p .+ p) == BlockedTuple{3,(2, 0, 1)}((6, 4, 2))
@test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,)))
@test (@constinferred blockedperm(p)) == p
@test (@constinferred blockedperm(bt)) == p

@test_throws ArgumentError blockedperm((1, 3), (2, 4); length=Val(6))
@test_throws ArgumentError blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(5))
@test (@constinferred blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(4))) ==
blockedperm((1, 3), (2, 4))

# Split collection into `BlockedPermutation`.
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
@test p == blockedperm((3, 1), (2, 4))
Expand All @@ -80,28 +86,47 @@ using TensorAlgebra:

# First dimensions are unspecified.
p = blockedperm(.., (4, 3))
@test p == blockedperm(1, 2, (4, 3))
@test p == blockedperm((1,), (2,), (4, 3))
# Specify length
p = blockedperm(.., (4, 3); length=Val(6))
@test p == blockedperm(1, 2, 5, 6, (4, 3))
p = @constinferred blockedperm(.., (4, 3); length=Val(6))
@test p == blockedperm((1,), (2,), (5,), (6,), (4, 3))

# Last dimensions are unspecified.
p = blockedperm((4, 3), ..)
@test p == blockedperm((4, 3), 1, 2)
@test p == blockedperm((4, 3), (1,), (2,))
# Specify length
p = blockedperm((4, 3), ..; length=Val(6))
@test p == blockedperm((4, 3), 1, 2, 5, 6)
p = @constinferred blockedperm((4, 3), ..; length=Val(6))
@test p == blockedperm((4, 3), (1,), (2,), (5,), (6,))

# Middle dimensions are unspecified.
p = blockedperm((4, 3), .., 1)
@test p == blockedperm((4, 3), 2, 1)
@test p == blockedperm((4, 3), (2,), (1,))
# Specify length
p = blockedperm((4, 3), .., 1; length=Val(6))
@test p == blockedperm((4, 3), 2, 5, 6, 1)
p = @constinferred blockedperm((4, 3), .., 1; length=Val(6))
@test p == blockedperm((4, 3), (2,), (5,), (6,), (1,))

# No dimensions are unspecified.
p = blockedperm((3, 2), .., 1)
@test p == blockedperm((3, 2), 1)
@test p == blockedperm((3, 2), (1,))

# same with (..,) instead of ..
p = blockedperm((..,), (4, 3))
@test p == blockedperm((1, 2), (4, 3))
p = @constinferred blockedperm((..,), (4, 3); length=Val(6))
@test p == blockedperm((1, 2, 5, 6), (4, 3))

p = blockedperm((4, 3), (..,))
@test p == blockedperm((4, 3), (1, 2))
p = @constinferred blockedperm((4, 3), (..,); length=Val(6))
@test p == blockedperm((4, 3), (1, 2, 5, 6))

p = blockedperm((4, 3), (..,), 1)
@test p == blockedperm((4, 3), (2,), (1,))
p = @constinferred blockedperm((4, 3), (..,), 1; length=Val(6))
@test p == blockedperm((4, 3), (2, 5, 6), (1,))

p = blockedperm((3, 2), (..,), 1)
@test p == blockedperm((3, 2), (), (1,))
end

@testset "BlockedTrivialPermutation" begin
Expand All @@ -113,11 +138,11 @@ end
@test blocklengths(tp) == (2, 0, 1)
@test trivialperm(blockedperm((3, 2), (), (1,))) == tp

bt = BlockedTuple{3,(2, 0, 1)}((1, 2, 3))
bt = tuplemortar(((1, 2), (), (3,)))
@test (@constinferred BlockedTuple(tp)) == bt
@test (@constinferred blocks(tp)) == blocks(bt)
@test (@constinferred map(identity, tp)) == bt
@test (@constinferred tp .+ tp) == BlockedTuple{3,(2, 0, 1)}((2, 4, 6))
@test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,)))
@test (@constinferred blockedperm(tp)) == tp
@test (@constinferred trivialperm(tp)) == tp
@test (@constinferred trivialperm(bt)) == tp
Expand Down
Loading