Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.7"
version = "0.5.8"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -28,7 +28,7 @@ LRUCache = "1.6"
LinearAlgebra = "1.10"
Random = "1.10"
Strided = "2.3"
TensorAlgebra = "0.3.8"
TensorAlgebra = "0.3.12"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4"
WignerSymbols = "2.0.0"
Expand Down
7 changes: 5 additions & 2 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) whe
return FusionTensor{T}(undef, new_axes)
end

Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")
function Base.show(io::IO, ft::FusionTensor)
return print(io, "$(ndims(ft))-dim FusionTensor with size $(size(ft))")
end

function Base.show(io::IO, ::MIME"text/plain", ft::FusionTensor)
print(io, "$(ndims(ft))-dim FusionTensor with axes:")
print(io, ft)
print(" and axes:")
for ax in axes(ft)
print(io, "\n", ax)
end
Expand Down
6 changes: 5 additions & 1 deletion src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using GradedArrays:
space_isequal
using LinearAlgebra: UniformScaling
using Random: Random, AbstractRNG, randn!
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar, length_codomain, length_domain
using TensorProducts: tensor_product
using TypeParameterAccessors: type_parameters

Expand Down Expand Up @@ -134,6 +134,10 @@ struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <
end
end

const FusionMatrix{T,Axes,Mat,Mapping} = FusionTensor{
T,2,Axes,Mapping
} where {BT<:BlockedTuple{2,(1, 1)},Axes<:FusionTensorAxes{BT}}

# ===================================== Accessors ========================================

data_matrix(ft::FusionTensor) = ft.data_matrix
Expand Down
14 changes: 9 additions & 5 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ using GradedArrays:
dual,
sector_type,
trivial
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTuple
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
AbstractBlockTuple,
BlockedTuple,
length_codomain,
length_domain
using TensorProducts: ⊗
using TypeParameterAccessors: type_parameters

Expand Down Expand Up @@ -65,6 +71,8 @@ TensorAlgebra.BlockedTuple(fta::FusionTensorAxes) = fta.outer_axes

TensorAlgebra.trivial_axis(fta::FusionTensorAxes) = trivial_axis(sector_type(fta))

TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))

# ================================== Base interface ======================================

for f in [
Expand Down Expand Up @@ -140,7 +148,3 @@ function fused_domain(fta::FusionTensorAxes)
end
return dual(⊗(dual.(domain(fta))...))
end

length_codomain(fta::FusionTensorAxes) = length(codomain(fta))

length_domain(fta::FusionTensorAxes) = length(domain(fta))
3 changes: 1 addition & 2 deletions src/fusiontensor/linear_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors
# allow to contract with different eltype and let BlockSparseArray ensure compatibility
# impose matching type and number of axes at compile time
# impose matching axes at run time
# TODO remove this once TensorAlgebra.contract can be used?
function LinearAlgebra.mul!(
C::FusionTensor, A::FusionTensor, B::FusionTensor, α::Number, β::Number
C::FusionMatrix, A::FusionMatrix, B::FusionMatrix, α::Number, β::Number
)

# compile time checks
Expand Down
97 changes: 64 additions & 33 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,78 @@ using LinearAlgebra: mul!

using BlockArrays: Block

using TensorAlgebra: BlockedPermutation, Matricize, TensorAlgebra
using GradedArrays: space_isequal
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTrivialPermutation,
BlockedTuple,
FusionStyle,
Matricize,
blockedperm,
genperm,
unmatricize

# TODO how to deal with inner contraction = no ouput axis?
# => currently biperm_dest is a BlockedPermutation{0}, change this
function TensorAlgebra.allocate_output(
function TensorAlgebra.output_axes(
::typeof(contract),
biperm_dest::BlockedPermutation{2},
biperm_dest::AbstractBlockPermutation{2},
a1::FusionTensor,
biperm1::BlockedPermutation{2},
biperm1::AbstractBlockPermutation{2},
a2::FusionTensor,
biperm2::BlockedPermutation{2},
α::Number=true,
biperm2::AbstractBlockPermutation{2},
α::Number=one(Bool),
)
axes_dest = (
map(i -> axes(a1)[i], first(blocks(biperm1))),
map(i -> axes(a2)[i], last(blocks(biperm2))),
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
@assert all(space_isequal.(dual.(axes_contracted), axes_contracted2))
flat_axes = genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest))
return FusionTensorAxes(
tuplemortar((
flat_axes[begin:length_codomain(biperm_dest)],
flat_axes[(length_codomain(biperm_dest) + 1):end],
)),
)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
end

# TBD do really I need to define these as I cannot use them in contract! and has to redefine it?
#TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation{2}) = permutedims(ft, perm)
#function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
#function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)

# I cannot use contract! from TensorAlgebra/src/contract/contract_matricize/contract.jl
# as it calls _mul!, which I should not overload.
# TBD define fallback _mul!(::AbstractArray, ::AbstractArray, ::AbstractArray) in TensorAlgebra?
function TensorAlgebra.contract!(
::Matricize,
a_dest::FusionTensor,
::BlockedPermutation{2},
a1::FusionTensor,
biperm1::BlockedPermutation{2},
a2::FusionTensor,
biperm2::BlockedPermutation{2},
α::Number,
β::Number,
struct FusionTensorFusionStyle <: FusionStyle end

TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()

function TensorAlgebra.matricize(
::FusionTensorFusionStyle, ft::AbstractArray, biperm::AbstractBlockPermutation{2}
)
permuted = permutedims(ft, biperm)
return FusionTensor(
data_matrix(permuted), (codomain_axis(permuted),), (domain_axis(permuted),)
)
end

# lift ambiguity
function TensorAlgebra.matricize(
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
)
return matricize(FusionTensorFusionStyle(), ft, blockedperm(BlockedTuple(tbp)))
end

function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
return FusionTensor(data_matrix(m), blocked_axes)
end

function TensorAlgebra.permuteblockeddims(
ft::FusionTensor, biperm::AbstractBlockPermutation
)
a1_perm = permutedims(a1, biperm1)
a2_perm = permutedims(a2, biperm2)
mul!(a_dest, a1_perm, a2_perm, α, β)
return permutedims(ft, biperm)
end

function TensorAlgebra.permuteblockeddims!(
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
)
return permutedims!(a, b, biperm)
end

# TODO define custom broadcast rules
function TensorAlgebra.unmatricize_add!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end
17 changes: 17 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test: @test, @test_throws, @testset
using BlockArrays: Block
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
using FusionTensors:
FusionMatrix,
FusionTensor,
FusionTensorAxes,
codomain_axes,
Expand Down Expand Up @@ -40,6 +41,7 @@ include("setup.jl")
fta = FusionTensorAxes((g1,), (g2,))
ft0 = FusionTensor{Float64}(undef, fta)
@test ft0 isa FusionTensor
@test ft0 isa FusionMatrix
@test space_isequal(codomain_axis(ft0), g1)
@test space_isequal(domain_axis(ft0), g2)

Expand Down Expand Up @@ -134,6 +136,8 @@ end
m2 = BlockSparseArray{Float64}(undef, gr, gc)
ft = FusionTensor(m2, (g1, g2), (g3, g4))

@test ft isa FusionTensor
@test !(ft isa FusionMatrix)
@test data_matrix(ft) == m2
@test checkspaces(codomain_axes(ft), (g1, g2))
@test checkspaces(domain_axes(ft), (g3, g4))
Expand All @@ -155,6 +159,8 @@ end

# one row axis
ft1 = FusionTensor{Float64}(undef, (g1,), ())
@test ft1 isa FusionTensor
@test !(ft1 isa FusionMatrix)
@test ndims_codomain(ft1) == 1
@test ndims_domain(ft1) == 0
@test ndims(ft1) == 1
Expand All @@ -165,6 +171,8 @@ end

# one column axis
ft2 = FusionTensor{Float64}(undef, (), (g1,))
@test ft2 isa FusionTensor
@test !(ft2 isa FusionMatrix)
@test ndims_codomain(ft2) == 0
@test ndims_domain(ft2) == 1
@test ndims(ft2) == 1
Expand All @@ -175,13 +183,22 @@ end

# zero axis
ft3 = FusionTensor{Float64}(undef, (), ())
@test ft3 isa FusionTensor
@test !(ft3 isa FusionMatrix)
@test ndims_codomain(ft3) == 0
@test ndims_domain(ft3) == 0
@test ndims(ft3) == 0
@test size(ft3) == tuplemortar(((), ()))
@test size(data_matrix(ft3)) == (1, 1)
@test isnothing(check_sanity(ft3))
@test sector_type(ft3) === TrivialSector

ft4 = FusionTensor{Float64}(undef, (g1, g1), ())
@test ft4 isa FusionTensor
@test !(ft4 isa FusionMatrix)
ft5 = FusionTensor{Float64}(undef, (), (g1, g1))
@test ft5 isa FusionTensor
@test !(ft5 isa FusionMatrix)
end

@testset "specific constructors" begin
Expand Down
67 changes: 53 additions & 14 deletions test/test_contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,37 @@ using LinearAlgebra: mul!
using Test: @test, @testset, @test_broken

using BlockSparseArrays: BlockSparseArray
using FusionTensors: FusionTensor, domain_axes, codomain_axes
using FusionTensors:
FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
using GradedArrays: U1, dual, gradedrange
using TensorAlgebra: contract, tuplemortar
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!

include("setup.jl")

@testset "matricize" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])

ft1 = randn(FusionTensorAxes((g1, g2), (dual(g3), dual(g4))))
m = matricize(ft1, (1, 2), (3, 4))
@test m isa FusionMatrix
ft2 = unmatricize(m, axes(ft1))
@test ft1 ≈ ft2

biperm = permmortar(((3,), (1, 2, 4)))
m2 = matricize(ft1, biperm)
ft_dest = FusionTensor{eltype(ft1)}(undef, axes(ft1)[biperm])
unmatricize!(ft_dest, m2, permmortar(((1,), (2, 3, 4))))
@test ft_dest ≈ permutedims(ft1, biperm)
@test ft_dest ≈ permutedims(ft1, biperm)

ft2 = similar(ft1)
unmatricize!(ft2, m2, biperm)
@test ft1 ≈ ft2
end

@testset "contraction" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
Expand All @@ -26,15 +51,12 @@ include("setup.jl")
@test codomain_axes(ft3) === codomain_axes(ft1)

# test LinearAlgebra.mul! with in-place matrix product
mul!(ft3, ft1, ft2)
@test isnothing(check_sanity(ft3))
@test domain_axes(ft3) === domain_axes(ft2)
@test codomain_axes(ft3) === codomain_axes(ft1)
m1 = randn(FusionTensorAxes((g1,), (g2,)))
m2 = randn(FusionTensorAxes((dual(g2),), (g3,)))
m3 = FusionTensor{Float64}(undef, (g1,), (g3,))

mul!(ft3, ft1, ft2, 1.0, 1.0)
@test isnothing(check_sanity(ft2))
@test domain_axes(ft3) === domain_axes(ft2)
@test codomain_axes(ft3) === codomain_axes(ft1)
mul!(m3, m1, m2, 2.0, 0.0)
@test m3 ≈ 2m1 * m2
end

@testset "TensorAlgebra interface" begin
Expand All @@ -43,19 +65,22 @@ end
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])

ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (dual(g1),))
ft3 = FusionTensor{Float64}(undef, dual.((g3, g4)), dual.((g1, g2)))
ft1 = randn(FusionTensorAxes((g1, g2), (g3, g4)))
ft2 = randn(FusionTensorAxes(dual.((g3, g4)), (dual(g1),)))
ft3 = randn(FusionTensorAxes(dual.((g3, g4)), dual.((g1, g2))))

ft4, legs = contract(ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test legs == tuplemortar(((1, 2), (5,)))
@test isnothing(check_sanity(ft4))
@test domain_axes(ft4) === domain_axes(ft2)
@test codomain_axes(ft4) === codomain_axes(ft1)
@test ft4 ≈ ft1 * ft2

ft5 = contract((1, 2, 5), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft5))
@test ft4 ≈ ft5
@test ndims_codomain(ft5) === 3
@test ndims_domain(ft5) === 0
@test permutedims(ft5, (1, 2), (3,)) ≈ ft4

ft6 = contract(tuplemortar(((1, 2), (5,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft6))
Expand All @@ -66,4 +91,18 @@ end
ft7, legs = contract(ft1, (1, 2, 3, 4), ft3, (3, 4, 1, 2))
@test legs == tuplemortar(((), ()))
@test ft7 isa FusionTensor{Float64,0}

# include permutations
ft6 = contract(tuplemortar(((5, 1), (2,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft6))
@test permutedims(ft6, (2, 3), (1,)) ≈ ft4

ft8 = contract(
tuplemortar(((-3,), (-1, -2, -4))), ft1, (-1, 1, -2, 2), ft3, (-3, 2, -4, 1)
)
left = permutedims(ft1, (1, 3), (2, 4))
right = permutedims(ft3, (4, 2), (1, 3))
lrprod = left * right
newft = permutedims(lrprod, (3,), (1, 2, 4))
@test newft ≈ ft8
end
Loading
Loading