diff --git a/Project.toml b/Project.toml index c0e814d..19abea2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 380eccc..663f480 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -3,6 +3,11 @@ using BlockArrays: using EllipsisNotation: Ellipsis, var".." using TupleTools: TupleTools +trivialperm(len) = ntuple(identity, len) +function istrivialperm(t::Tuple) + return t == trivialperm(length(t)) +end + value(::Val{N}) where {N} = N _flatten_tuples(t::Tuple) = t diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 2beff4c..3d9efa3 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -67,6 +67,22 @@ function output_axes( return genperm((axes_dest...,), invperm(Tuple(perm_dest))) end +# Outer product. +function output_axes( + ::typeof(contract), + biperm_dest::BlockedPermutation{2}, + a1::AbstractArray, + perm1::BlockedPermutation{1}, + a2::AbstractArray, + perm2::BlockedPermutation{1}, + α::Number=true, +) + @assert istrivialperm(Tuple(perm1)) + @assert istrivialperm(Tuple(perm2)) + axes_dest = (axes(a1)..., axes(a2)...) + return genperm(axes_dest, invperm(Tuple(biperm_dest))) +end + # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index f56e7e0..5fffd2b 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -54,3 +54,11 @@ function _mul!( mul!(a_dest, a1, a2, α, β) return a_dest end + +# Outer product. +function _mul!( + a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number +) + mul!(a_dest, a1, transpose(a2), α, β) + return a_dest +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 5a5bda2..09fb26a 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -121,6 +121,31 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) 50 * default_rtol(elt_dest) end end + @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, + elt2 in elts + + a1 = randn(elt1, 2, 3) + a2 = randn(elt2, 4, 5) + + elt_dest = promote_type(elt1, elt2) + + a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l")) + @test labels == ("i", "j", "k", "l") + @test eltype(a_dest) === elt_dest + @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) + + a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) + @test eltype(a_dest) === elt_dest + @test a_dest ≈ permutedims( + reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) + ) + + a_dest = zeros(elt_dest, 2, 5, 3, 4) + TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) + @test a_dest ≈ permutedims( + reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) + ) + end end @testset "qr (eltype=$elt)" for elt in elts a = randn(elt, 5, 4, 3, 2)