diff --git a/Project.toml b/Project.toml index 5d3d28f..bbd6420 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.7" +version = "0.2.8" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index 6f33134..a41033a 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedpermvcat(permblocks_dest...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...) + biperm1 = blockedpermvcat(permblocks1...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...) + biperm2 = blockedpermvcat(permblocks2...) return biperm_dest, biperm1, biperm2 end diff --git a/src/contract/contract.jl b/src/contract/contract.jl index dae9441..1cfb061 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -13,11 +13,11 @@ default_contract_alg() = Matricize() function contract!( alg::Algorithm, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) @@ -110,11 +110,11 @@ end function contract( alg::Algorithm, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number; kwargs..., ) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 1750ae2..02d08c9 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,103 +1,22 @@ using LinearAlgebra: mul! function contract!( - alg::Matricize, + ::Matricize, a_dest::AbstractArray, - biperm_dest::BlockedPermutation, + biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, - biperm1::BlockedPermutation, + biperm1::AbstractBlockPermutation{2}, a2::AbstractArray, - biperm2::BlockedPermutation, + biperm2::AbstractBlockPermutation{2}, α::Number, β::Number, ) a_dest_mat = fusedims(a_dest, biperm_dest) a1_mat = fusedims(a1, biperm1) a2_mat = fusedims(a2, biperm2) - _mul!(a_dest_mat, a1_mat, a2_mat, α, β) + @assert ndims(a1_mat) == 2 + @assert ndims(a2_mat) == 2 + mul!(a_dest_mat, a1_mat, a2_mat, α, β) splitdims!(a_dest, a_dest_mat, biperm_dest) return a_dest end - -# Matrix multiplication. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Inner product. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractVector, - a2::AbstractVector, - α::Number, - β::Number, -) - a_dest[] = transpose(a1) * a2 * α + a_dest[] * β - return a_dest -end - -# Vec-mat. -function _mul!( - a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number -) - mul!(transpose(a_dest), transpose(a1), a2, α, β) - return a_dest -end - -# Mat-vec. -function _mul!( - a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, a2, α, β) - return a_dest -end - -# Outer product. -function _mul!( - a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number -) - mul!(a_dest, a1, transpose(a2), α, β) - return a_dest -end - -# Array-scalar contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractVector, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - α′ = a2[] * α - a_dest .= a1 .* α′ .+ a_dest .* β - return a_dest -end - -# Scalar-array contraction. -function _mul!( - a_dest::AbstractVector, - a1::AbstractArray{<:Any,0}, - a2::AbstractVector, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest .= a1[] .* a2 .* α .+ a_dest .* β - return a_dest -end - -# Scalar-scalar contraction. -function _mul!( - a_dest::AbstractArray{<:Any,0}, - a1::AbstractArray{<:Any,0}, - a2::AbstractArray{<:Any,0}, - α::Number, - β::Number, -) - # Preserve the ordering in case of non-commutative algebra. - a_dest[] = a1[] * a2[] * α + a_dest[] * β - return a_dest -end diff --git a/src/factorizations.jl b/src/factorizations.jl index 4adcadb..1316cda 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -37,7 +37,9 @@ function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.. biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) return qr(A, biperm; kwargs...) end -function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...) +function qr( + A::AbstractArray, biperm::AbstractBlockPermutation{2}; full::Bool=false, kwargs... +) # tensor to matrix A_mat = fusedims(A, biperm) @@ -46,8 +48,8 @@ function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_Q = (axes_codomain..., axes(Q, 2)) - axes_R = (axes(R, 1), axes_domain...) + axes_Q = tuplemortar((axes_codomain, (axes(q_matricized, 2),))) + axes_R = tuplemortar(((axes(r_matricized, 1),), axes_domain)) return splitdims(Q, axes_Q), splitdims(R, axes_R) end @@ -80,8 +82,8 @@ function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_L = (axes_codomain..., axes(L, ndims(L))) - axes_Q = (axes(Q, 1), axes_domain...) + axes_L = tuplemortar((axes_codomain, (axes(L, ndims(L)),))) + axes_Q = tuplemortar(((axes(Q, 1),), axes_domain)) return splitdims(L, axes_L), splitdims(Q, axes_Q) end @@ -128,7 +130,7 @@ function eigen( # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) - axes_V = (axes_codomain..., axes(V, ndims(V))) + axes_V = tuplemortar((axes_codomain, (axes(V, ndims(V)),))) return D, splitdims(V, axes_V) end @@ -202,8 +204,8 @@ function svd( # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) - axes_U = (axes_codomain..., axes(U, 2)) - axes_Vᴴ = (axes(Vᴴ, 1), axes_domain...) + axes_U = tuplemortar((axes_codomain, (axes(U, 2),))) + axes_Vᴴ = tuplemortar(((axes(Vᴴ, 1),), axes_domain)) return splitdims(U, axes_U), S, splitdims(Vᴴ, axes_Vᴴ) end @@ -251,7 +253,7 @@ function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) N = left_null!(A_mat; kwargs...) axes_codomain, _ = blockpermute(axes(A), biperm) - axes_N = (axes_codomain..., axes(N, 2)) + axes_N = tuplemortar((axes_codomain, (axes(N, 2),))) N_tensor = splitdims(N, axes_N) return N_tensor end @@ -281,6 +283,6 @@ function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) A_mat = fusedims(A, biperm) Nᴴ = right_null!(A_mat; kwargs...) _, axes_domain = blockpermute(axes(A), biperm) - axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) + axes_Nᴴ = tuplemortar((axes(Nᴴ, 1), (axes_domain,))) return splitdims(Nᴴ, axes_Nᴴ) end diff --git a/src/fusedims.jl b/src/fusedims.jl index 87bab85..e0d53fd 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -12,6 +12,7 @@ combine_fusion_styles(style1::Style, style2::Style) where {Style<:FusionStyle} = combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion() combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles) FusionStyle(axis::AbstractUnitRange) = ReshapeFusion() +FusionStyle(::Tuple{}) = ReshapeFusion() function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}}) return combine_fusion_styles(FusionStyle.(axes)...) end @@ -27,7 +28,6 @@ function fusedims(a::AbstractArray, ax::AbstractUnitRange, axes::AbstractUnitRan return fusedims(FusionStyle(a), a, ax, axes...) end -# Overload this version for fusion tensors, array maps, etc. function fusedims( a::AbstractArray, axb::Tuple{Vararg{AbstractUnitRange}}, @@ -36,13 +36,6 @@ function fusedims( return fusedims(a, flatten_tuples((axb, axesblocks...))...) end -# Fix ambiguity issue -fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a - -function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) -end - function fuseaxes( axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation ) @@ -60,7 +53,18 @@ function fusedims(a::AbstractArray, blockedperm::BlockedTrivialPermutation) return fusedims(a, axes_fused) end -function fusedims(a::AbstractArray, blockedperm::BlockedPermutation) - a_perm = _permutedims(a, Tuple(blockedperm)) - return fusedims(a_perm, trivialperm(blockedperm)) +# deal with zero-dim case +fusedims(a::AbstractArray{<:Any,0}, t::Tuple{}...) = reshape(a, ntuple(_ -> 1, length(t))) + +function fusedims(a::AbstractArray, bt::AbstractBlockTuple) + # TBD define permutedims(::AbstractArray, ::AbstractBlockPermutation) + # TBD remove call to BlockedTrivialPermutation? + a_perm = _permutedims(a, Tuple(bt)) + return fusedims(a_perm, trivialperm(bt)) +end + +# fusedims(ones((2,2,2,2)), (3, 1, 2), (4,)) +# fusedims(ones((2,2,2,2)), (3, 1, 2), 4) +function fusedims(a::AbstractArray, permblocks...) + return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) end diff --git a/src/splitdims.jl b/src/splitdims.jl index 0554c61..78684b4 100644 --- a/src/splitdims.jl +++ b/src/splitdims.jl @@ -4,32 +4,41 @@ to_axis(a::AbstractUnitRange) = a to_axis(n::Integer) = Base.OneTo(n) function blockedaxes(a::AbstractArray, sizeblocks::Pair...) - axes_a = axes(a) axes_split = tuple.(axes(a)) for (dim, sizeblock) in sizeblocks # TODO: Handle conversion from length to range! axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim) end - return axes_split + return tuplemortar(axes_split) end -# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) -function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...) +function splitdims(::ReshapeFusion, a::AbstractArray, abt::BlockedTuple) # TODO: Add `uncanonicalizedims`. # TODO: Need `length` since `reshape` doesn't accept `axes`, # maybe make a `reshape_axes` function. - return reshape(a, length.(axes)...) + return reshape(a, Tuple(length.(abt))) +end + +# ambiguity for zero-dim +function splitdims(a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,Tuple{}}) where {N} + return splitdims(FusionStyle(a), a, abt) +end + +function splitdims( + a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} +) where {N} + return splitdims(FusionStyle(a), a, bt) end # splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2) function splitdims(a::AbstractArray, axes::AbstractUnitRange...) - return splitdims(FusionStyle(a), a, axes...) + return splitdims(a, tuple.(axes)...) end # splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2)) function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...) # TODO: Add `uncanonicalizedims`. - return splitdims(a, flatten_tuples(axesblocks)...) + return splitdims(a, tuplemortar(axesblocks)) end # Fix ambiguity issue @@ -37,32 +46,26 @@ splitdims(a::AbstractArray) = a # splitdims(randn(4, 4), (2, 2), (2, 2)) function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...) - return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...) + return splitdims(a, tuplemortar(sizeblocks)) end -# splitdims(randn(4, 4), 2 => (1:2, 1:2)) -function splitdims(a::AbstractArray, sizeblocks::Pair...) - return splitdims(a, blockedaxes(a, sizeblocks...)...) +# splitdims(randn(4, 4), tuplemortar(((2, 2), (2, 2)))) +function splitdims( + a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{Integer}}} +) where {N} + return splitdims(a, to_axis.(bt)) end -# TODO: Is this needed? -function splitdims( - a::AbstractArray, - axes_dest::Tuple{Vararg{AbstractUnitRange}}, - blockedperm::BlockedPermutation, -) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) - a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm))) - return a_dest +# splitdims(randn(4, 4), 2 => (1:2, 1:2)) +function splitdims(a::AbstractArray, sizeblocks::Pair...) + return splitdims(a, blockedaxes(a, sizeblocks...)) end function splitdims!( - a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation + a_dest::AbstractArray, a::AbstractArray, blockedperm::AbstractBlockPermutation ) - axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm)) - # TODO: Pass grouped axes. - a_dest_perm = splitdims(a, axes_dest...) + axes_dest = map(i -> axes(a_dest, i), blockedperm) + a_dest_perm = splitdims(a, axes_dest) _permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm))) return a_dest end diff --git a/test/test_basics.jl b/test/test_basics.jl index 3bce1b7..b21a5d6 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,8 @@ using EllipsisNotation: var".." using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd +using TensorAlgebra: + blockedperm, contract, contract!, fusedims, qr, splitdims, svd, tuplemortar using TensorOperations: TensorOperations using Test: @test, @test_broken, @testset @@ -11,6 +12,16 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin @testset "fusedims (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) + + bt = tuplemortar(((3, 2), (4, 1))) + p = blockedperm(bt) + a_fused = fusedims(a, p) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 2, 4, 1)), (12, 10)) + a_fused = fusedims(a, bt) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 2, 4, 1)), (12, 10)) + a_fused = fusedims(a, (1, 2), (3, 4)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(a, 6, 20) @@ -35,9 +46,30 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_fused = fusedims(a, (3, 1), (..,)) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + + a = randn(elt, ()) + a_fused = fusedims(a) + @test a_fused isa Array{elt,0} + @test a_fused ≈ a + a_fused = fusedims(a, ()) + @test a_fused isa Array{elt,1} + a_fused = fusedims(a, (), ()) + @test a_fused isa Array{elt,2} + a_fused = fusedims(a, tuplemortar(((), ()))) + @test a_fused isa Array{elt,2} end + @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) + + a_split = splitdims(a, tuplemortar(((1:2, 1:3), (1:5, 1:4)))) + @test a_split isa Array{elt,4} + @test a_split ≈ reshape(a, (2, 3, 5, 4)) + + a_split = splitdims(a, tuplemortar(((2, 3), (5, 4)))) + @test a_split isa Array{elt,4} + @test a_split ≈ reshape(a, (2, 3, 5, 4)) + a_split = splitdims(a, (2, 3), (5, 4)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 5, 4)) @@ -62,6 +94,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_split = splitdims(a, 1 => (1:2, 1:3)) @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 20)) + + a_split = splitdims(a) + @test a_split isa Array{elt,2} + @test a_split ≈ a end using TensorOperations: TensorOperations @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts