Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.7"
version = "0.3.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
16 changes: 16 additions & 0 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
using Base.PermutedDimsArrays: genperm

function checkndims(a1, labels1, a2, labels2)
ndims(a1) == length(labels1) ||
throw(ArgumentError("Invalid permutation for left tensor"))
return ndims(a2) == length(labels2) ||
throw(ArgumentError("Invalid permutation for right tensor"))
end

function checkndims(a_dest, labels_dest, a1, labels1, a2, labels2)
ndims(a_dest) == length(labels_dest) ||
throw(ArgumentError("Invalid permutation for destination tensor"))
return checkndims(a1, labels1, a2, labels2)
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function output_axes(
Expand Down Expand Up @@ -28,6 +41,9 @@ function allocate_output(
biperm2::AbstractBlockPermutation,
α::Number=one(Bool),
)
checkndims(a1, biperm1, a2, biperm2)
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
throw(ArgumentError("Invalid permutation for destination tensor"))
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
end
5 changes: 5 additions & 0 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ end

# codomain <-- domain
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
for i in unique(dimnames)
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
end

codomain = Tuple(setdiff(dimnames1, dimnames2))
contracted = Tuple(intersect(dimnames1, dimnames2))
domain = Tuple(setdiff(dimnames2, dimnames1))
Expand Down
3 changes: 3 additions & 0 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function contract(
α::Number=one(Bool);
kwargs...,
)
checkndims(a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
end
Expand All @@ -104,6 +105,7 @@ function contract!(
β::Number;
kwargs...,
)
checkndims(a_dest, labels_dest, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
end
Expand All @@ -118,6 +120,7 @@ function contract(
α::Number;
kwargs...,
)
checkndims(a1, biperm1, a2, biperm2)
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
return a_dest
Expand Down
1 change: 1 addition & 0 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function contract!(
α::Number,
β::Number,
)
checkndims(a_dest, biperm_dest, a1, biperm1, a2, biperm2)
a_dest_mat = matricize(a_dest, biperm_dest)
a1_mat = matricize(a1, biperm1)
a2_mat = matricize(a2, biperm2)
Expand Down
5 changes: 3 additions & 2 deletions src/contract/output_labels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ function output_labels(
return output_labels(f, alg, labels1, labels2)
end

function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2)
function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2)
return output_labels(f, labels1, labels2)
end

function output_labels(::typeof(contract), labels1, labels2)
return Tuple(symdiff(labels1, labels2))
diff = symdiff(labels1, labels2)
return tuplemortar((Tuple(intersect(diff, labels1)), Tuple(intersect(diff, labels2))))
end
4 changes: 4 additions & 0 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end
# maybe: copy=false kwarg

function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2})
ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation"))
return matricize(FusionStyle(a), a, biperm)
end

Expand Down Expand Up @@ -78,6 +79,7 @@ function unmatricize(
axes::Tuple{Vararg{AbstractUnitRange}},
biperm::AbstractBlockPermutation{2},
)
length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation"))
return unmatricize(FusionStyle(m), m, axes, biperm)
end

Expand Down Expand Up @@ -122,6 +124,8 @@ end
function unmatricize!(
a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2}
)
ndims(a) == length(biperm) ||
throw(ArgumentError("destination does not match permutation"))
blocked_axes = axes(a)[biperm]
a_perm = unmatricize(m, blocked_axes)
return permuteblockeddims!(a, a_perm, invperm(biperm))
Expand Down
39 changes: 33 additions & 6 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using StableRNGs: StableRNG
using TensorOperations: TensorOperations

using TensorAlgebra:
BlockedTuple,
blockedpermvcat,
permuteblockeddims,
permuteblockeddims!,
Expand Down Expand Up @@ -61,6 +62,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@test_throws MethodError matricize(a, (1, 2), (3,), (4,))
@test_throws MethodError matricize(a, (1, 2, 3, 4))
@test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,)))

v = ones(elt, 2)
a_fused = matricize(v, (1,), ())
Expand Down Expand Up @@ -122,10 +124,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = unmatricize(m, (), ())
@test a isa Array{elt,0}
@test a[] == m[1, 1]

@test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,)))
@test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,)))
end

using TensorOperations: TensorOperations
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
elt_dest = promote_type(elt1, elt2)
a1 = ones(elt1, (1, 1))
a2 = ones(elt2, (1, 1))
a_dest = ones(elt_dest, (1, 1))
@test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3))
@test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4))
@test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3))
@test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4))
@test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3))

dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
for (d1s, d2s, d_dests) in (
Expand Down Expand Up @@ -155,8 +170,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

# Don't specify destination labels
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
@test labels_dest′ isa
BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))}
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
Tuple(labels_dest′), a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

Expand All @@ -167,8 +184,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
)
@test a_dest ≈ a_dest_tensoroperations

# Specify with bituple
a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations
a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations
a_dest = contract(labels_dest′, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
Tuple(labels_dest′), a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

# Specify α and β
elt_dest = promote_type(elt1, elt2)
# TODO: Using random `α`, `β` causing
# random test failures, investigate why.
α = elt_dest(1.2) # randn(elt_dest)
Expand All @@ -195,7 +222,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a2 = randn(rng, elt2, 4, 5)

a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
@test labels == ("i", "j", "k", "l")
@test labels == tuplemortar((("i", "j"), ("k", "l")))
@test eltype(a_dest) === elt_dest
@test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))

Expand Down Expand Up @@ -225,17 +252,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

# Array-scalar contraction.
a_dest, labels_dest = contract(a, labels_a, s, ())
@test labels_dest == labels_a
@test labels_dest == tuplemortar((labels_a, ()))
@test a_dest ≈ a * s[]

# Scalar-array contraction.
a_dest, labels_dest = contract(s, (), a, labels_a)
@test labels_dest == labels_a
@test labels_dest == tuplemortar(((), labels_a))
@test a_dest ≈ a * s[]

# Scalar-scalar contraction.
a_dest, labels_dest = contract(s, (), t, ())
@test labels_dest == ()
@test labels_dest == tuplemortar(((), ()))
@test a_dest[] ≈ s[] * t[]

# Specify output labels.
Expand Down
Loading