Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.7"
version = "0.3.8"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
16 changes: 16 additions & 0 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
using Base.PermutedDimsArrays: genperm

function check_input(::typeof(contract), a1, labels1, a2, labels2)
ndims(a1) == length(labels1) ||
throw(ArgumentError("Invalid permutation for left tensor"))
return ndims(a2) == length(labels2) ||
throw(ArgumentError("Invalid permutation for right tensor"))
end

function check_input(::typeof(contract), a_dest, labels_dest, a1, labels1, a2, labels2)
ndims(a_dest) == length(labels_dest) ||
throw(ArgumentError("Invalid permutation for destination tensor"))
return check_input(contract, a1, labels1, a2, labels2)
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function output_axes(
Expand Down Expand Up @@ -28,6 +41,9 @@ function allocate_output(
biperm2::AbstractBlockPermutation,
α::Number=one(Bool),
)
check_input(contract, a1, biperm1, a2, biperm2)
blocklengths(biperm_dest) == (length(biperm1[Block(1)]), length(biperm2[Block(2)])) ||
throw(ArgumentError("Invalid permutation for destination tensor"))
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
end
5 changes: 5 additions & 0 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ end

# codomain <-- domain
function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
dimnames = collect(Iterators.flatten((dimnames_dest, dimnames1, dimnames2)))
for i in unique(dimnames)
count(==(i), dimnames) == 2 || throw(ArgumentError("Invalid contraction labels"))
end

codomain = Tuple(setdiff(dimnames1, dimnames2))
contracted = Tuple(intersect(dimnames1, dimnames2))
domain = Tuple(setdiff(dimnames2, dimnames1))
Expand Down
3 changes: 3 additions & 0 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function contract(
α::Number=one(Bool);
kwargs...,
)
check_input(contract, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
end
Expand All @@ -104,6 +105,7 @@ function contract!(
β::Number;
kwargs...,
)
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
end
Expand All @@ -118,6 +120,7 @@ function contract(
α::Number;
kwargs...,
)
check_input(contract, a1, biperm1, a2, biperm2)
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
return a_dest
Expand Down
1 change: 1 addition & 0 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function contract!(
α::Number,
β::Number,
)
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
a_dest_mat = matricize(a_dest, biperm_dest)
a1_mat = matricize(a1, biperm1)
a2_mat = matricize(a2, biperm2)
Expand Down
6 changes: 4 additions & 2 deletions src/contract/output_labels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ function output_labels(
return output_labels(f, alg, labels1, labels2)
end

function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2)
function output_labels(f::typeof(contract), ::Algorithm, labels1, labels2)
return output_labels(f, labels1, labels2)
end

function output_labels(::typeof(contract), labels1, labels2)
return Tuple(symdiff(labels1, labels2))
diff1 = Tuple(setdiff(labels1, labels2))
diff2 = Tuple(setdiff(labels2, labels1))
return tuplemortar((diff1, diff2))
end
4 changes: 4 additions & 0 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end
# maybe: copy=false kwarg

function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2})
ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation"))
return matricize(FusionStyle(a), a, biperm)
end

Expand Down Expand Up @@ -78,6 +79,7 @@ function unmatricize(
axes::Tuple{Vararg{AbstractUnitRange}},
biperm::AbstractBlockPermutation{2},
)
length(axes) == length(biperm) || throw(ArgumentError("axes do not match permutation"))
return unmatricize(FusionStyle(m), m, axes, biperm)
end

Expand Down Expand Up @@ -122,6 +124,8 @@ end
function unmatricize!(
a::AbstractArray, m::AbstractMatrix, biperm::AbstractBlockPermutation{2}
)
ndims(a) == length(biperm) ||
throw(ArgumentError("destination does not match permutation"))
blocked_axes = axes(a)[biperm]
a_perm = unmatricize(m, blocked_axes)
return permuteblockeddims!(a, a_perm, invperm(biperm))
Expand Down
39 changes: 33 additions & 6 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using StableRNGs: StableRNG
using TensorOperations: TensorOperations

using TensorAlgebra:
BlockedTuple,
blockedpermvcat,
permuteblockeddims,
permuteblockeddims!,
Expand Down Expand Up @@ -61,6 +62,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@test_throws MethodError matricize(a, (1, 2), (3,), (4,))
@test_throws MethodError matricize(a, (1, 2, 3, 4))
@test_throws ArgumentError matricize(a, blockedpermvcat((1, 2), (3,)))

v = ones(elt, 2)
a_fused = matricize(v, (1,), ())
Expand Down Expand Up @@ -122,10 +124,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = unmatricize(m, (), ())
@test a isa Array{elt,0}
@test a[] == m[1, 1]

@test_throws ArgumentError unmatricize(m, (), blockedpermvcat((1, 2), (3,)))
@test_throws ArgumentError unmatricize!(m, m, blockedpermvcat((1, 2), (3,)))
end

using TensorOperations: TensorOperations
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
elt_dest = promote_type(elt1, elt2)
a1 = ones(elt1, (1, 1))
a2 = ones(elt2, (1, 1))
a_dest = ones(elt_dest, (1, 1))
@test_throws ArgumentError contract(a1, (1, 2, 4), a2, (2, 3))
@test_throws ArgumentError contract(a1, (1, 2), a2, (2, 3, 4))
@test_throws ArgumentError contract((1, 3, 4), a1, (1, 2), a2, (2, 3))
@test_throws ArgumentError contract((1, 3), a1, (1, 2), a2, (2, 4))
@test_throws ArgumentError contract!(a_dest, (1, 3, 4), a1, (1, 2), a2, (2, 3))

dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
for (d1s, d2s, d_dests) in (
Expand Down Expand Up @@ -155,8 +170,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

# Don't specify destination labels
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
@test labels_dest′ isa
BlockedTuple{2,(length(setdiff(d1s, d2s)), length(setdiff(d2s, d1s)))}
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
Tuple(labels_dest′), a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

Expand All @@ -167,8 +184,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
)
@test a_dest ≈ a_dest_tensoroperations

# Specify with bituple
a_dest = contract(tuplemortar((labels_dest, ())), a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations
a_dest = contract(tuplemortar(((), labels_dest)), a1, labels1, a2, labels2)
@test a_dest ≈ a_dest_tensoroperations
a_dest = contract(labels_dest′, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
Tuple(labels_dest′), a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

# Specify α and β
elt_dest = promote_type(elt1, elt2)
# TODO: Using random `α`, `β` causing
# random test failures, investigate why.
α = elt_dest(1.2) # randn(elt_dest)
Expand All @@ -195,7 +222,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a2 = randn(rng, elt2, 4, 5)

a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
@test labels == ("i", "j", "k", "l")
@test labels == tuplemortar((("i", "j"), ("k", "l")))
@test eltype(a_dest) === elt_dest
@test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))

Expand Down Expand Up @@ -225,17 +252,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

# Array-scalar contraction.
a_dest, labels_dest = contract(a, labels_a, s, ())
@test labels_dest == labels_a
@test labels_dest == tuplemortar((labels_a, ()))
@test a_dest ≈ a * s[]

# Scalar-array contraction.
a_dest, labels_dest = contract(s, (), a, labels_a)
@test labels_dest == labels_a
@test labels_dest == tuplemortar(((), labels_a))
@test a_dest ≈ a * s[]

# Scalar-scalar contraction.
a_dest, labels_dest = contract(s, (), t, ())
@test labels_dest == ()
@test labels_dest == tuplemortar(((), ()))
@test a_dest[] ≈ s[] * t[]

# Specify output labels.
Expand Down
Loading