diff --git a/Project.toml b/Project.toml index 79d953e..6deb6f6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FusionTensors" uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e" authors = ["ITensor developers and contributors"] -version = "0.5.7" +version = "0.5.8" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -28,7 +28,7 @@ LRUCache = "1.6" LinearAlgebra = "1.10" Random = "1.10" Strided = "2.3" -TensorAlgebra = "0.3.8" +TensorAlgebra = "0.3.12" TensorProducts = "0.1.7" TypeParameterAccessors = "0.4" WignerSymbols = "2.0.0" diff --git a/src/fusiontensor/base_interface.jl b/src/fusiontensor/base_interface.jl index e3af133..36b7c89 100644 --- a/src/fusiontensor/base_interface.jl +++ b/src/fusiontensor/base_interface.jl @@ -122,10 +122,13 @@ function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) whe return FusionTensor{T}(undef, new_axes) end -Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor") +function Base.show(io::IO, ft::FusionTensor) + return print(io, "$(ndims(ft))-dim FusionTensor with size $(size(ft))") +end function Base.show(io::IO, ::MIME"text/plain", ft::FusionTensor) - print(io, "$(ndims(ft))-dim FusionTensor with axes:") + print(io, ft) + print(" and axes:") for ax in axes(ft) print(io, "\n", ax) end diff --git a/src/fusiontensor/fusiontensor.jl b/src/fusiontensor/fusiontensor.jl index d3e60d4..6f9d93d 100644 --- a/src/fusiontensor/fusiontensor.jl +++ b/src/fusiontensor/fusiontensor.jl @@ -22,7 +22,7 @@ using GradedArrays: space_isequal using LinearAlgebra: UniformScaling using Random: Random, AbstractRNG, randn! -using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar +using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar, length_codomain, length_domain using TensorProducts: tensor_product using TypeParameterAccessors: type_parameters @@ -134,6 +134,10 @@ struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} < end end +const FusionMatrix{T,Axes,Mat,Mapping} = FusionTensor{ + T,2,Axes,Mapping +} where {BT<:BlockedTuple{2,(1, 1)},Axes<:FusionTensorAxes{BT}} + # ===================================== Accessors ======================================== data_matrix(ft::FusionTensor) = ft.data_matrix diff --git a/src/fusiontensor/fusiontensoraxes.jl b/src/fusiontensor/fusiontensoraxes.jl index f3e5712..46d2478 100644 --- a/src/fusiontensor/fusiontensoraxes.jl +++ b/src/fusiontensor/fusiontensoraxes.jl @@ -8,7 +8,13 @@ using GradedArrays: dual, sector_type, trivial -using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTuple +using TensorAlgebra: + TensorAlgebra, + AbstractBlockPermutation, + AbstractBlockTuple, + BlockedTuple, + length_codomain, + length_domain using TensorProducts: ⊗ using TypeParameterAccessors: type_parameters @@ -65,6 +71,8 @@ TensorAlgebra.BlockedTuple(fta::FusionTensorAxes) = fta.outer_axes TensorAlgebra.trivial_axis(fta::FusionTensorAxes) = trivial_axis(sector_type(fta)) +TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta)) + # ================================== Base interface ====================================== for f in [ @@ -140,7 +148,3 @@ function fused_domain(fta::FusionTensorAxes) end return dual(⊗(dual.(domain(fta))...)) end - -length_codomain(fta::FusionTensorAxes) = length(codomain(fta)) - -length_domain(fta::FusionTensorAxes) = length(domain(fta)) diff --git a/src/fusiontensor/linear_algebra_interface.jl b/src/fusiontensor/linear_algebra_interface.jl index 68cbb70..b63bc03 100644 --- a/src/fusiontensor/linear_algebra_interface.jl +++ b/src/fusiontensor/linear_algebra_interface.jl @@ -10,9 +10,8 @@ using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors # allow to contract with different eltype and let BlockSparseArray ensure compatibility # impose matching type and number of axes at compile time # impose matching axes at run time -# TODO remove this once TensorAlgebra.contract can be used? function LinearAlgebra.mul!( - C::FusionTensor, A::FusionTensor, B::FusionTensor, α::Number, β::Number + C::FusionMatrix, A::FusionMatrix, B::FusionMatrix, α::Number, β::Number ) # compile time checks diff --git a/src/fusiontensor/tensor_algebra_interface.jl b/src/fusiontensor/tensor_algebra_interface.jl index 3129fa5..b75ba09 100644 --- a/src/fusiontensor/tensor_algebra_interface.jl +++ b/src/fusiontensor/tensor_algebra_interface.jl @@ -4,47 +4,70 @@ using LinearAlgebra: mul! using BlockArrays: Block -using TensorAlgebra: BlockedPermutation, Matricize, TensorAlgebra +using GradedArrays: space_isequal +using TensorAlgebra: + TensorAlgebra, + AbstractBlockPermutation, + BlockedTrivialPermutation, + BlockedTuple, + FusionStyle, + Matricize, + blockedperm, + genperm, + unmatricize -# TODO how to deal with inner contraction = no ouput axis? -# => currently biperm_dest is a BlockedPermutation{0}, change this -function TensorAlgebra.allocate_output( +function TensorAlgebra.output_axes( ::typeof(contract), - biperm_dest::BlockedPermutation{2}, + biperm_dest::AbstractBlockPermutation{2}, a1::FusionTensor, - biperm1::BlockedPermutation{2}, + biperm1::AbstractBlockPermutation{2}, a2::FusionTensor, - biperm2::BlockedPermutation{2}, - α::Number=true, + biperm2::AbstractBlockPermutation{2}, + α::Number=one(Bool), ) - axes_dest = ( - map(i -> axes(a1)[i], first(blocks(biperm1))), - map(i -> axes(a2)[i], last(blocks(biperm2))), + axes_codomain, axes_contracted = blocks(axes(a1)[biperm1]) + axes_contracted2, axes_domain = blocks(axes(a2)[biperm2]) + @assert all(space_isequal.(dual.(axes_contracted), axes_contracted2)) + flat_axes = genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest)) + return FusionTensorAxes( + tuplemortar(( + flat_axes[begin:length_codomain(biperm_dest)], + flat_axes[(length_codomain(biperm_dest) + 1):end], + )), ) - return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) end -# TBD do really I need to define these as I cannot use them in contract! and has to redefine it? -#TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation{2}) = permutedims(ft, perm) -#function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation) -#function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation) - -# I cannot use contract! from TensorAlgebra/src/contract/contract_matricize/contract.jl -# as it calls _mul!, which I should not overload. -# TBD define fallback _mul!(::AbstractArray, ::AbstractArray, ::AbstractArray) in TensorAlgebra? -function TensorAlgebra.contract!( - ::Matricize, - a_dest::FusionTensor, - ::BlockedPermutation{2}, - a1::FusionTensor, - biperm1::BlockedPermutation{2}, - a2::FusionTensor, - biperm2::BlockedPermutation{2}, - α::Number, - β::Number, +struct FusionTensorFusionStyle <: FusionStyle end + +TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle() + +function TensorAlgebra.matricize( + ::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2} +) + blocklengths(biperm) == blocklengths(axes(ft)) || + throw(ArgumentError("Invalid trivial biperm")) + return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),)) +end + +function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes) + return FusionTensor(data_matrix(m), blocked_axes) +end + +function TensorAlgebra.permuteblockeddims( + ft::FusionTensor, biperm::AbstractBlockPermutation ) - a1_perm = permutedims(a1, biperm1) - a2_perm = permutedims(a2, biperm2) - mul!(a_dest, a1_perm, a2_perm, α, β) + return permutedims(ft, biperm) +end + +function TensorAlgebra.permuteblockeddims!( + a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation +) + return permutedims!(a, b, biperm) +end + +# TODO define custom broadcast rules +function TensorAlgebra.unmatricize_add!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β) + a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) + data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest) return a_dest end diff --git a/test/test_basics.jl b/test/test_basics.jl index a575664..44476ef 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -3,6 +3,7 @@ using Test: @test, @test_throws, @testset using BlockArrays: Block using BlockSparseArrays: BlockSparseArray, eachblockstoredindex using FusionTensors: + FusionMatrix, FusionTensor, FusionTensorAxes, codomain_axes, @@ -40,6 +41,7 @@ include("setup.jl") fta = FusionTensorAxes((g1,), (g2,)) ft0 = FusionTensor{Float64}(undef, fta) @test ft0 isa FusionTensor + @test ft0 isa FusionMatrix @test space_isequal(codomain_axis(ft0), g1) @test space_isequal(domain_axis(ft0), g2) @@ -134,6 +136,8 @@ end m2 = BlockSparseArray{Float64}(undef, gr, gc) ft = FusionTensor(m2, (g1, g2), (g3, g4)) + @test ft isa FusionTensor + @test !(ft isa FusionMatrix) @test data_matrix(ft) == m2 @test checkspaces(codomain_axes(ft), (g1, g2)) @test checkspaces(domain_axes(ft), (g3, g4)) @@ -155,6 +159,8 @@ end # one row axis ft1 = FusionTensor{Float64}(undef, (g1,), ()) + @test ft1 isa FusionTensor + @test !(ft1 isa FusionMatrix) @test ndims_codomain(ft1) == 1 @test ndims_domain(ft1) == 0 @test ndims(ft1) == 1 @@ -165,6 +171,8 @@ end # one column axis ft2 = FusionTensor{Float64}(undef, (), (g1,)) + @test ft2 isa FusionTensor + @test !(ft2 isa FusionMatrix) @test ndims_codomain(ft2) == 0 @test ndims_domain(ft2) == 1 @test ndims(ft2) == 1 @@ -175,6 +183,8 @@ end # zero axis ft3 = FusionTensor{Float64}(undef, (), ()) + @test ft3 isa FusionTensor + @test !(ft3 isa FusionMatrix) @test ndims_codomain(ft3) == 0 @test ndims_domain(ft3) == 0 @test ndims(ft3) == 0 @@ -182,6 +192,13 @@ end @test size(data_matrix(ft3)) == (1, 1) @test isnothing(check_sanity(ft3)) @test sector_type(ft3) === TrivialSector + + ft4 = FusionTensor{Float64}(undef, (g1, g1), ()) + @test ft4 isa FusionTensor + @test !(ft4 isa FusionMatrix) + ft5 = FusionTensor{Float64}(undef, (), (g1, g1)) + @test ft5 isa FusionTensor + @test !(ft5 isa FusionMatrix) end @testset "specific constructors" begin diff --git a/test/test_contraction.jl b/test/test_contraction.jl index 4317560..2b07877 100644 --- a/test/test_contraction.jl +++ b/test/test_contraction.jl @@ -2,12 +2,37 @@ using LinearAlgebra: mul! using Test: @test, @testset, @test_broken using BlockSparseArrays: BlockSparseArray -using FusionTensors: FusionTensor, domain_axes, codomain_axes +using FusionTensors: + FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes using GradedArrays: U1, dual, gradedrange -using TensorAlgebra: contract, tuplemortar +using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize! include("setup.jl") +@testset "matricize" begin + g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3]) + g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]) + g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1]) + g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1]) + + ft1 = randn(FusionTensorAxes((g1, g2), (dual(g3), dual(g4)))) + m = matricize(ft1, (1, 2), (3, 4)) + @test m isa FusionMatrix + ft2 = unmatricize(m, axes(ft1)) + @test ft1 ≈ ft2 + + biperm = permmortar(((3,), (1, 2, 4))) + m2 = matricize(ft1, biperm) + ft_dest = FusionTensor{eltype(ft1)}(undef, axes(ft1)[biperm]) + unmatricize!(ft_dest, m2, permmortar(((1,), (2, 3, 4)))) + @test ft_dest ≈ permutedims(ft1, biperm) + @test ft_dest ≈ permutedims(ft1, biperm) + + ft2 = similar(ft1) + unmatricize!(ft2, m2, biperm) + @test ft1 ≈ ft2 +end + @testset "contraction" begin g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3]) g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]) @@ -26,15 +51,12 @@ include("setup.jl") @test codomain_axes(ft3) === codomain_axes(ft1) # test LinearAlgebra.mul! with in-place matrix product - mul!(ft3, ft1, ft2) - @test isnothing(check_sanity(ft3)) - @test domain_axes(ft3) === domain_axes(ft2) - @test codomain_axes(ft3) === codomain_axes(ft1) + m1 = randn(FusionTensorAxes((g1,), (g2,))) + m2 = randn(FusionTensorAxes((dual(g2),), (g3,))) + m3 = FusionTensor{Float64}(undef, (g1,), (g3,)) - mul!(ft3, ft1, ft2, 1.0, 1.0) - @test isnothing(check_sanity(ft2)) - @test domain_axes(ft3) === domain_axes(ft2) - @test codomain_axes(ft3) === codomain_axes(ft1) + mul!(m3, m1, m2, 2.0, 0.0) + @test m3 ≈ 2m1 * m2 end @testset "TensorAlgebra interface" begin @@ -43,19 +65,22 @@ end g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1]) g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1]) - ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4)) - ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (dual(g1),)) - ft3 = FusionTensor{Float64}(undef, dual.((g3, g4)), dual.((g1, g2))) + ft1 = randn(FusionTensorAxes((g1, g2), (g3, g4))) + ft2 = randn(FusionTensorAxes(dual.((g3, g4)), (dual(g1),))) + ft3 = randn(FusionTensorAxes(dual.((g3, g4)), dual.((g1, g2)))) ft4, legs = contract(ft1, (1, 2, 3, 4), ft2, (3, 4, 5)) @test legs == tuplemortar(((1, 2), (5,))) @test isnothing(check_sanity(ft4)) @test domain_axes(ft4) === domain_axes(ft2) @test codomain_axes(ft4) === codomain_axes(ft1) + @test ft4 ≈ ft1 * ft2 ft5 = contract((1, 2, 5), ft1, (1, 2, 3, 4), ft2, (3, 4, 5)) @test isnothing(check_sanity(ft5)) - @test ft4 ≈ ft5 + @test ndims_codomain(ft5) === 3 + @test ndims_domain(ft5) === 0 + @test permutedims(ft5, (1, 2), (3,)) ≈ ft4 ft6 = contract(tuplemortar(((1, 2), (5,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5)) @test isnothing(check_sanity(ft6)) @@ -66,4 +91,18 @@ end ft7, legs = contract(ft1, (1, 2, 3, 4), ft3, (3, 4, 1, 2)) @test legs == tuplemortar(((), ())) @test ft7 isa FusionTensor{Float64,0} + + # include permutations + ft6 = contract(tuplemortar(((5, 1), (2,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5)) + @test isnothing(check_sanity(ft6)) + @test permutedims(ft6, (2, 3), (1,)) ≈ ft4 + + ft8 = contract( + tuplemortar(((-3,), (-1, -2, -4))), ft1, (-1, 1, -2, 2), ft3, (-3, 2, -4, 1) + ) + left = permutedims(ft1, (1, 3), (2, 4)) + right = permutedims(ft3, (4, 2), (1, 3)) + lrprod = left * right + newft = permutedims(lrprod, (3,), (1, 2, 4)) + @test newft ≈ ft8 end diff --git a/test/test_fusiontensoraxes.jl b/test/test_fusiontensoraxes.jl index 6ac3f4d..f40df4e 100644 --- a/test/test_fusiontensoraxes.jl +++ b/test/test_fusiontensoraxes.jl @@ -2,7 +2,7 @@ using Test: @test, @test_throws, @testset using TensorProducts: ⊗ using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks -using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar +using TensorAlgebra: BlockedTuple, length_codomain, trivial_axis, tuplemortar using FusionTensors: FusionTensorAxes, @@ -75,6 +75,7 @@ end @test blocklength(fta) == 2 @test blocklengths(fta) == (2, 2) @test blocks(fta) == blocks(bt) + @test length_codomain(fta) == 2 @test sector_type(fta) === sector_type(g2) @test length(codomain(fta)) == 2 @@ -107,6 +108,7 @@ end @test blocklength(fta) == 2 @test blocklengths(fta) == (0, 0) @test sector_type(fta) == TrivialSector + @test length_codomain(fta) == 0 @test codomain(fta) == () @test space_isequal(fused_codomain(fta), trivial_axis(TrivialSector))