Skip to content

Make better use of BlockedTuple in contract logic to track codomain and domain #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.7"
version = "0.2.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
6 changes: 3 additions & 3 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...)
biperm_dest = blockedpermvcat(permblocks_dest...)

Check warning on line 25 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L25

Added line #L25 was not covered by tests
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...)
biperm1 = blockedpermvcat(permblocks1...)

Check warning on line 27 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L27

Added line #L27 was not covered by tests
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...)
biperm2 = blockedpermvcat(permblocks2...)

Check warning on line 29 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L29

Added line #L29 was not covered by tests
return biperm_dest, biperm1, biperm2
end
12 changes: 6 additions & 6 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ default_contract_alg() = Matricize()
function contract!(
alg::Algorithm,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
Expand Down Expand Up @@ -110,11 +110,11 @@ end

function contract(
alg::Algorithm,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number;
kwargs...,
)
Expand Down
95 changes: 7 additions & 88 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
@@ -1,103 +1,22 @@
using LinearAlgebra: mul!

function contract!(
alg::Matricize,
::Matricize,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
a_dest_mat = fusedims(a_dest, biperm_dest)
a1_mat = fusedims(a1, biperm1)
a2_mat = fusedims(a2, biperm2)
_mul!(a_dest_mat, a1_mat, a2_mat, α, β)
@assert ndims(a1_mat) == 2
@assert ndims(a2_mat) == 2
mul!(a_dest_mat, a1_mat, a2_mat, α, β)

Check warning on line 19 in src/contract/contract_matricize/contract.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/contract_matricize/contract.jl#L17-L19

Added lines #L17 - L19 were not covered by tests
splitdims!(a_dest, a_dest_mat, biperm_dest)
return a_dest
end

# Matrix multiplication.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Inner product.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractVector,
a2::AbstractVector,
α::Number,
β::Number,
)
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Outer product.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end

# Array-scalar contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractVector,
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
α′ = a2[] * α
a_dest .= a1 .* α′ .+ a_dest .* β
return a_dest
end

# Scalar-array contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractArray{<:Any,0},
a2::AbstractVector,
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
return a_dest
end

# Scalar-scalar contraction.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractArray{<:Any,0},
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest[] = a1[] * a2[] * α + a_dest[] * β
return a_dest
end
22 changes: 12 additions & 10 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return qr(A, biperm; kwargs...)
end
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
function qr(

Check warning on line 40 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L40

Added line #L40 was not covered by tests
A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs...
)
# tensor to matrix
A_mat = fusedims(A, biperm)

Expand All @@ -46,8 +48,8 @@

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_Q = (axes_codomain..., axes(Q, 2))
axes_R = (axes(R, 1), axes_domain...)
axes_Q = tuplemortar((axes_codomain, (axes(q_matricized, 2),)))
axes_R = tuplemortar(((axes(r_matricized, 1),), axes_domain))

Check warning on line 52 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L51-L52

Added lines #L51 - L52 were not covered by tests
return splitdims(Q, axes_Q), splitdims(R, axes_R)
end

Expand Down Expand Up @@ -80,8 +82,8 @@

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_L = (axes_codomain..., axes(L, ndims(L)))
axes_Q = (axes(Q, 1), axes_domain...)
axes_L = tuplemortar((axes_codomain, (axes(L, ndims(L)),)))
axes_Q = tuplemortar(((axes(Q, 1),), axes_domain))

Check warning on line 86 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
return splitdims(L, axes_L), splitdims(Q, axes_Q)
end

Expand Down Expand Up @@ -128,7 +130,7 @@

# matrix to tensor
axes_codomain, = blockpermute(axes(A), biperm)
axes_V = (axes_codomain..., axes(V, ndims(V)))
axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),)))

Check warning on line 133 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L133

Added line #L133 was not covered by tests
return D, splitdims(V, axes_V)
end

Expand Down Expand Up @@ -202,8 +204,8 @@

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_U = (axes_codomain..., axes(U, 2))
axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...)
axes_U = tuplemortar((axes_codomain, (axes(U, 2),)))
axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain))

Check warning on line 208 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L207-L208

Added lines #L207 - L208 were not covered by tests
return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ)
end

Expand Down Expand Up @@ -251,7 +253,7 @@
A_mat = fusedims(A, biperm)
N = left_null!(A_mat; kwargs...)
axes_codomain, _ = blockpermute(axes(A), biperm)
axes_N = (axes_codomain..., axes(N, 2))
axes_N = tuplemortar((axes_codomain, (axes(N, 2),)))

Check warning on line 256 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L256

Added line #L256 was not covered by tests
N_tensor = splitdims(N, axes_N)
return N_tensor
end
Expand Down Expand Up @@ -281,6 +283,6 @@
A_mat = fusedims(A, biperm)
Nᴴ = right_null!(A_mat; kwargs...)
_, axes_domain = blockpermute(axes(A), biperm)
axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...)
axes_Nᴴ = tuplemortar((axes(Nᴴ, 1), (axes_domain,)))

Check warning on line 286 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L286

Added line #L286 was not covered by tests
return splitdims(Nᴴ, axes_Nᴴ)
end
26 changes: 15 additions & 11 deletions src/fusedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion()
combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles)
FusionStyle(axis::AbstractUnitRange) = ReshapeFusion()
FusionStyle(::Tuple{}) = ReshapeFusion()

Check warning on line 15 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L15

Added line #L15 was not covered by tests
function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}})
return combine_fusion_styles(FusionStyle.(axes)...)
end
Expand All @@ -27,7 +28,6 @@
return fusedims(FusionStyle(a), a, ax, axes...)
end

# Overload this version for fusion tensors, array maps, etc.
function fusedims(
a::AbstractArray,
axb::Tuple{Vararg{AbstractUnitRange}},
Expand All @@ -36,13 +36,6 @@
return fusedims(a, flatten_tuples((axb, axesblocks...))...)
end

# Fix ambiguity issue
fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a

function fusedims(a::AbstractArray, permblocks...)
return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a))))
end

function fuseaxes(
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
)
Expand All @@ -60,7 +53,18 @@
return fusedims(a, axes_fused)
end

function fusedims(a::AbstractArray, blockedperm::BlockedPermutation)
a_perm = _permutedims(a, Tuple(blockedperm))
return fusedims(a_perm, trivialperm(blockedperm))
# deal with zero-dim case
fusedims(a::AbstractArray{<:Any,0}, t::Tuple{}...) = reshape(a, ntuple(_ -> 1, length(t)))

Check warning on line 57 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L57

Added line #L57 was not covered by tests

function fusedims(a::AbstractArray, bt::AbstractBlockTuple)

Check warning on line 59 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L59

Added line #L59 was not covered by tests
# TBD define permutedims(::AbstractArray, ::AbstractBlockPermutation)
# TBD remove call to BlockedTrivialPermutation?
a_perm = _permutedims(a, Tuple(bt))
return fusedims(a_perm, trivialperm(bt))

Check warning on line 63 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L62-L63

Added lines #L62 - L63 were not covered by tests
end

# fusedims(ones((2,2,2,2)), (3, 1, 2), (4,))
# fusedims(ones((2,2,2,2)), (3, 1, 2), 4)
function fusedims(a::AbstractArray, permblocks...)
return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a))))

Check warning on line 69 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
end
53 changes: 28 additions & 25 deletions src/splitdims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,65 +4,68 @@
to_axis(n::Integer) = Base.OneTo(n)

function blockedaxes(a::AbstractArray, sizeblocks::Pair...)
axes_a = axes(a)
axes_split = tuple.(axes(a))
for (dim, sizeblock) in sizeblocks
# TODO: Handle conversion from length to range!
axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim)
end
return axes_split
return tuplemortar(axes_split)

Check warning on line 12 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L12

Added line #L12 was not covered by tests
end

# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2)
function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...)
function splitdims(::ReshapeFusion, a::AbstractArray, abt::BlockedTuple)

Check warning on line 15 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L15

Added line #L15 was not covered by tests
# TODO: Add `uncanonicalizedims`.
# TODO: Need `length` since `reshape` doesn't accept `axes`,
# maybe make a `reshape_axes` function.
return reshape(a, length.(axes)...)
return reshape(a, Tuple(length.(abt)))

Check warning on line 19 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L19

Added line #L19 was not covered by tests
end

# ambiguity for zero-dim
function splitdims(a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,Tuple{}}) where {N}
return splitdims(FusionStyle(a), a, abt)

Check warning on line 24 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
end

function splitdims(

Check warning on line 27 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L27

Added line #L27 was not covered by tests
a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}
) where {N}
return splitdims(FusionStyle(a), a, bt)

Check warning on line 30 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L30

Added line #L30 was not covered by tests
end

# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2)
function splitdims(a::AbstractArray, axes::AbstractUnitRange...)
return splitdims(FusionStyle(a), a, axes...)
return splitdims(a, tuple.(axes)...)

Check warning on line 35 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L35

Added line #L35 was not covered by tests
end

# splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2))
function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...)
# TODO: Add `uncanonicalizedims`.
return splitdims(a, flatten_tuples(axesblocks)...)
return splitdims(a, tuplemortar(axesblocks))

Check warning on line 41 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L41

Added line #L41 was not covered by tests
end

# Fix ambiguity issue
splitdims(a::AbstractArray) = a

# splitdims(randn(4, 4), (2, 2), (2, 2))
function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...)
return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...)
return splitdims(a, tuplemortar(sizeblocks))

Check warning on line 49 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L49

Added line #L49 was not covered by tests
end

# splitdims(randn(4, 4), 2 => (1:2, 1:2))
function splitdims(a::AbstractArray, sizeblocks::Pair...)
return splitdims(a, blockedaxes(a, sizeblocks...)...)
# splitdims(randn(4, 4), tuplemortar(((2, 2), (2, 2))))
function splitdims(

Check warning on line 53 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L53

Added line #L53 was not covered by tests
a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{Integer}}}
) where {N}
return splitdims(a, to_axis.(bt))

Check warning on line 56 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L56

Added line #L56 was not covered by tests
end

# TODO: Is this needed?
function splitdims(
a::AbstractArray,
axes_dest::Tuple{Vararg{AbstractUnitRange}},
blockedperm::BlockedPermutation,
)
# TODO: Pass grouped axes.
a_dest_perm = splitdims(a, axes_dest...)
a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm)))
return a_dest
# splitdims(randn(4, 4), 2 => (1:2, 1:2))
function splitdims(a::AbstractArray, sizeblocks::Pair...)
return splitdims(a, blockedaxes(a, sizeblocks...))

Check warning on line 61 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end

function splitdims!(
a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation
a_dest::AbstractArray, a::AbstractArray, blockedperm::AbstractBlockPermutation
)
axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm))
# TODO: Pass grouped axes.
a_dest_perm = splitdims(a, axes_dest...)
axes_dest = map(i -> axes(a_dest, i), blockedperm)
a_dest_perm = splitdims(a, axes_dest)

Check warning on line 68 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
_permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm)))
return a_dest
end
Loading
Loading