Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.7"
version = "0.5.8"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -28,7 +28,7 @@ LRUCache = "1.6"
LinearAlgebra = "1.10"
Random = "1.10"
Strided = "2.3"
TensorAlgebra = "0.3.8"
TensorAlgebra = "0.3.12"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4"
WignerSymbols = "2.0.0"
Expand Down
7 changes: 5 additions & 2 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ function Base.similar(::FusionTensor, ::Type{T}, new_axes::FusionTensorAxes) whe
return FusionTensor{T}(undef, new_axes)
end

Base.show(io::IO, ft::FusionTensor) = print(io, "$(ndims(ft))-dim FusionTensor")
function Base.show(io::IO, ft::FusionTensor)
return print(io, "$(ndims(ft))-dim FusionTensor with size $(size(ft))")
end

function Base.show(io::IO, ::MIME"text/plain", ft::FusionTensor)
print(io, "$(ndims(ft))-dim FusionTensor with axes:")
print(io, ft)
print(" and axes:")
for ax in axes(ft)
print(io, "\n", ax)
end
Expand Down
6 changes: 5 additions & 1 deletion src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using GradedArrays:
space_isequal
using LinearAlgebra: UniformScaling
using Random: Random, AbstractRNG, randn!
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar, length_codomain, length_domain
using TensorProducts: tensor_product
using TypeParameterAccessors: type_parameters

Expand Down Expand Up @@ -134,6 +134,10 @@ struct FusionTensor{T,N,Axes<:FusionTensorAxes,Mat<:AbstractMatrix{T},Mapping} <
end
end

const FusionMatrix{T,Axes,Mat,Mapping} = FusionTensor{
T,2,Axes,Mapping
} where {BT<:BlockedTuple{2,(1, 1)},Axes<:FusionTensorAxes{BT}}

# ===================================== Accessors ========================================

data_matrix(ft::FusionTensor) = ft.data_matrix
Expand Down
14 changes: 9 additions & 5 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ using GradedArrays:
dual,
sector_type,
trivial
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTuple
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
AbstractBlockTuple,
BlockedTuple,
length_codomain,
length_domain
using TensorProducts: ⊗
using TypeParameterAccessors: type_parameters

Expand Down Expand Up @@ -65,6 +71,8 @@ TensorAlgebra.BlockedTuple(fta::FusionTensorAxes) = fta.outer_axes

TensorAlgebra.trivial_axis(fta::FusionTensorAxes) = trivial_axis(sector_type(fta))

TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))

# ================================== Base interface ======================================

for f in [
Expand Down Expand Up @@ -140,7 +148,3 @@ function fused_domain(fta::FusionTensorAxes)
end
return dual(⊗(dual.(domain(fta))...))
end

length_codomain(fta::FusionTensorAxes) = length(codomain(fta))

length_domain(fta::FusionTensorAxes) = length(domain(fta))
3 changes: 1 addition & 2 deletions src/fusiontensor/linear_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors
# allow to contract with different eltype and let BlockSparseArray ensure compatibility
# impose matching type and number of axes at compile time
# impose matching axes at run time
# TODO remove this once TensorAlgebra.contract can be used?
function LinearAlgebra.mul!(
C::FusionTensor, A::FusionTensor, B::FusionTensor, α::Number, β::Number
C::FusionMatrix, A::FusionMatrix, B::FusionMatrix, α::Number, β::Number
)

# compile time checks
Expand Down
89 changes: 56 additions & 33 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,70 @@ using LinearAlgebra: mul!

using BlockArrays: Block

using TensorAlgebra: BlockedPermutation, Matricize, TensorAlgebra
using GradedArrays: space_isequal
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTrivialPermutation,
BlockedTuple,
FusionStyle,
Matricize,
blockedperm,
genperm,
unmatricize

# TODO how to deal with inner contraction = no ouput axis?
# => currently biperm_dest is a BlockedPermutation{0}, change this
function TensorAlgebra.allocate_output(
function TensorAlgebra.output_axes(
::typeof(contract),
biperm_dest::BlockedPermutation{2},
biperm_dest::AbstractBlockPermutation{2},
a1::FusionTensor,
biperm1::BlockedPermutation{2},
biperm1::AbstractBlockPermutation{2},
a2::FusionTensor,
biperm2::BlockedPermutation{2},
α::Number=true,
biperm2::AbstractBlockPermutation{2},
α::Number=one(Bool),
)
axes_dest = (
map(i -> axes(a1)[i], first(blocks(biperm1))),
map(i -> axes(a2)[i], last(blocks(biperm2))),
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
@assert all(space_isequal.(dual.(axes_contracted), axes_contracted2))
flat_axes = genperm((axes_codomain..., axes_domain...), Tuple(biperm_dest))
return FusionTensorAxes(
tuplemortar((
flat_axes[begin:length_codomain(biperm_dest)],
flat_axes[(length_codomain(biperm_dest) + 1):end],
)),
)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
end

# TBD do really I need to define these as I cannot use them in contract! and has to redefine it?
#TensorAlgebra.fusedims(ft::FusionTensor, perm::BlockedPermutation{2}) = permutedims(ft, perm)
#function TensorAlgebra.splitdims(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)
#function TensorAlgebra.splitdims!(ft1::FusionTensor, ft2::FusionTensor, blockedperm::BlockedPermutation)

# I cannot use contract! from TensorAlgebra/src/contract/contract_matricize/contract.jl
# as it calls _mul!, which I should not overload.
# TBD define fallback _mul!(::AbstractArray, ::AbstractArray, ::AbstractArray) in TensorAlgebra?
function TensorAlgebra.contract!(
::Matricize,
a_dest::FusionTensor,
::BlockedPermutation{2},
a1::FusionTensor,
biperm1::BlockedPermutation{2},
a2::FusionTensor,
biperm2::BlockedPermutation{2},
α::Number,
β::Number,
struct FusionTensorFusionStyle <: FusionStyle end

TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()

function TensorAlgebra.matricize(
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
)
blocklengths(biperm) == blocklengths(axes(ft)) ||
throw(ArgumentError("Invalid trivial biperm"))
return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),))
end

function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
return FusionTensor(data_matrix(m), blocked_axes)
end

function TensorAlgebra.permuteblockeddims(
ft::FusionTensor, biperm::AbstractBlockPermutation
)
a1_perm = permutedims(a1, biperm1)
a2_perm = permutedims(a2, biperm2)
mul!(a_dest, a1_perm, a2_perm, α, β)
return permutedims(ft, biperm)
end

function TensorAlgebra.permuteblockeddims!(
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
)
return permutedims!(a, b, biperm)
end

# TODO define custom broadcast rules
function TensorAlgebra.unmatricize_add!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end
17 changes: 17 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test: @test, @test_throws, @testset
using BlockArrays: Block
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
using FusionTensors:
FusionMatrix,
FusionTensor,
FusionTensorAxes,
codomain_axes,
Expand Down Expand Up @@ -40,6 +41,7 @@ include("setup.jl")
fta = FusionTensorAxes((g1,), (g2,))
ft0 = FusionTensor{Float64}(undef, fta)
@test ft0 isa FusionTensor
@test ft0 isa FusionMatrix
@test space_isequal(codomain_axis(ft0), g1)
@test space_isequal(domain_axis(ft0), g2)

Expand Down Expand Up @@ -134,6 +136,8 @@ end
m2 = BlockSparseArray{Float64}(undef, gr, gc)
ft = FusionTensor(m2, (g1, g2), (g3, g4))

@test ft isa FusionTensor
@test !(ft isa FusionMatrix)
@test data_matrix(ft) == m2
@test checkspaces(codomain_axes(ft), (g1, g2))
@test checkspaces(domain_axes(ft), (g3, g4))
Expand All @@ -155,6 +159,8 @@ end

# one row axis
ft1 = FusionTensor{Float64}(undef, (g1,), ())
@test ft1 isa FusionTensor
@test !(ft1 isa FusionMatrix)
@test ndims_codomain(ft1) == 1
@test ndims_domain(ft1) == 0
@test ndims(ft1) == 1
Expand All @@ -165,6 +171,8 @@ end

# one column axis
ft2 = FusionTensor{Float64}(undef, (), (g1,))
@test ft2 isa FusionTensor
@test !(ft2 isa FusionMatrix)
@test ndims_codomain(ft2) == 0
@test ndims_domain(ft2) == 1
@test ndims(ft2) == 1
Expand All @@ -175,13 +183,22 @@ end

# zero axis
ft3 = FusionTensor{Float64}(undef, (), ())
@test ft3 isa FusionTensor
@test !(ft3 isa FusionMatrix)
@test ndims_codomain(ft3) == 0
@test ndims_domain(ft3) == 0
@test ndims(ft3) == 0
@test size(ft3) == tuplemortar(((), ()))
@test size(data_matrix(ft3)) == (1, 1)
@test isnothing(check_sanity(ft3))
@test sector_type(ft3) === TrivialSector

ft4 = FusionTensor{Float64}(undef, (g1, g1), ())
@test ft4 isa FusionTensor
@test !(ft4 isa FusionMatrix)
ft5 = FusionTensor{Float64}(undef, (), (g1, g1))
@test ft5 isa FusionTensor
@test !(ft5 isa FusionMatrix)
end

@testset "specific constructors" begin
Expand Down
67 changes: 53 additions & 14 deletions test/test_contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,37 @@ using LinearAlgebra: mul!
using Test: @test, @testset, @test_broken

using BlockSparseArrays: BlockSparseArray
using FusionTensors: FusionTensor, domain_axes, codomain_axes
using FusionTensors:
FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
using GradedArrays: U1, dual, gradedrange
using TensorAlgebra: contract, tuplemortar
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!

include("setup.jl")

@testset "matricize" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])

ft1 = randn(FusionTensorAxes((g1, g2), (dual(g3), dual(g4))))
m = matricize(ft1, (1, 2), (3, 4))
@test m isa FusionMatrix
ft2 = unmatricize(m, axes(ft1))
@test ft1 ≈ ft2

biperm = permmortar(((3,), (1, 2, 4)))
m2 = matricize(ft1, biperm)
ft_dest = FusionTensor{eltype(ft1)}(undef, axes(ft1)[biperm])
unmatricize!(ft_dest, m2, permmortar(((1,), (2, 3, 4))))
@test ft_dest ≈ permutedims(ft1, biperm)
@test ft_dest ≈ permutedims(ft1, biperm)

ft2 = similar(ft1)
unmatricize!(ft2, m2, biperm)
@test ft1 ≈ ft2
end

@testset "contraction" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
Expand All @@ -26,15 +51,12 @@ include("setup.jl")
@test codomain_axes(ft3) === codomain_axes(ft1)

# test LinearAlgebra.mul! with in-place matrix product
mul!(ft3, ft1, ft2)
@test isnothing(check_sanity(ft3))
@test domain_axes(ft3) === domain_axes(ft2)
@test codomain_axes(ft3) === codomain_axes(ft1)
m1 = randn(FusionTensorAxes((g1,), (g2,)))
m2 = randn(FusionTensorAxes((dual(g2),), (g3,)))
m3 = FusionTensor{Float64}(undef, (g1,), (g3,))

mul!(ft3, ft1, ft2, 1.0, 1.0)
@test isnothing(check_sanity(ft2))
@test domain_axes(ft3) === domain_axes(ft2)
@test codomain_axes(ft3) === codomain_axes(ft1)
mul!(m3, m1, m2, 2.0, 0.0)
@test m3 ≈ 2m1 * m2
end

@testset "TensorAlgebra interface" begin
Expand All @@ -43,19 +65,22 @@ end
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])

ft1 = FusionTensor{Float64}(undef, (g1, g2), (g3, g4))
ft2 = FusionTensor{Float64}(undef, dual.((g3, g4)), (dual(g1),))
ft3 = FusionTensor{Float64}(undef, dual.((g3, g4)), dual.((g1, g2)))
ft1 = randn(FusionTensorAxes((g1, g2), (g3, g4)))
ft2 = randn(FusionTensorAxes(dual.((g3, g4)), (dual(g1),)))
ft3 = randn(FusionTensorAxes(dual.((g3, g4)), dual.((g1, g2))))

ft4, legs = contract(ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test legs == tuplemortar(((1, 2), (5,)))
@test isnothing(check_sanity(ft4))
@test domain_axes(ft4) === domain_axes(ft2)
@test codomain_axes(ft4) === codomain_axes(ft1)
@test ft4 ≈ ft1 * ft2

ft5 = contract((1, 2, 5), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft5))
@test ft4 ≈ ft5
@test ndims_codomain(ft5) === 3
@test ndims_domain(ft5) === 0
@test permutedims(ft5, (1, 2), (3,)) ≈ ft4

ft6 = contract(tuplemortar(((1, 2), (5,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft6))
Expand All @@ -66,4 +91,18 @@ end
ft7, legs = contract(ft1, (1, 2, 3, 4), ft3, (3, 4, 1, 2))
@test legs == tuplemortar(((), ()))
@test ft7 isa FusionTensor{Float64,0}

# include permutations
ft6 = contract(tuplemortar(((5, 1), (2,))), ft1, (1, 2, 3, 4), ft2, (3, 4, 5))
@test isnothing(check_sanity(ft6))
@test permutedims(ft6, (2, 3), (1,)) ≈ ft4

ft8 = contract(
tuplemortar(((-3,), (-1, -2, -4))), ft1, (-1, 1, -2, 2), ft3, (-3, 2, -4, 1)
)
left = permutedims(ft1, (1, 3), (2, 4))
right = permutedims(ft3, (4, 2), (1, 3))
lrprod = left * right
newft = permutedims(lrprod, (3,), (1, 2, 4))
@test newft ≈ ft8
end
4 changes: 3 additions & 1 deletion test/test_fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Test: @test, @test_throws, @testset

using TensorProducts: ⊗
using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
using TensorAlgebra: BlockedTuple, length_codomain, trivial_axis, tuplemortar

using FusionTensors:
FusionTensorAxes,
Expand Down Expand Up @@ -75,6 +75,7 @@ end
@test blocklength(fta) == 2
@test blocklengths(fta) == (2, 2)
@test blocks(fta) == blocks(bt)
@test length_codomain(fta) == 2

@test sector_type(fta) === sector_type(g2)
@test length(codomain(fta)) == 2
Expand Down Expand Up @@ -107,6 +108,7 @@ end
@test blocklength(fta) == 2
@test blocklengths(fta) == (0, 0)
@test sector_type(fta) == TrivialSector
@test length_codomain(fta) == 0

@test codomain(fta) == ()
@test space_isequal(fused_codomain(fta), trivial_axis(TrivialSector))
Expand Down
Loading