Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
Accessors = "0.1.42"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.4"
GradedArrays = "0.4.13"
GradedArrays = "0.4.14"
HalfIntegers = "1.6"
LRUCache = "1.6"
LinearAlgebra = "1.10"
Expand Down
7 changes: 4 additions & 3 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Accessors: @set
using BlockSparseArrays: @view!, eachstoredblock
using GradedArrays: checkspaces, checkspaces_dual
using TensorAlgebra: BlockedTuple, tuplemortar

set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
Expand All @@ -11,21 +12,21 @@ Base.:*(ft::FusionTensor, x::Number) = x * ft

# tensor contraction is a block data_matrix product.
function Base.:*(left::FusionTensor, right::FusionTensor)
checkaxes_dual(domain_axes(left), codomain_axes(right))
checkspaces_dual(domain_axes(left), codomain_axes(right))
new_data_matrix = data_matrix(left) * data_matrix(right)
return FusionTensor(new_data_matrix, codomain_axes(left), domain_axes(right))
end

# tensor addition is a block data_matrix add.
function Base.:+(left::FusionTensor, right::FusionTensor)
checkaxes(axes(left), axes(right))
checkspaces(axes(left), axes(right))
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
end

Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))

function Base.:-(left::FusionTensor, right::FusionTensor)
checkaxes(axes(left), axes(right))
checkspaces(axes(left), axes(right))
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
end

Expand Down
24 changes: 4 additions & 20 deletions src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using GradedArrays:
SectorProduct,
TrivialSector,
dual,
findfirstblock,
flip,
flip_dual,
gradedrange,
Expand All @@ -27,23 +28,6 @@ using TypeParameterAccessors: type_parameters

# ======================================= Misc ===========================================

# TBD move to GradedArrays? rename findfirst_sector?
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
return findfirst(==(s), sectors(flip_dual(g)))
end

# TBD move to GradedArrays?
function checkaxes(::Type{Bool}, axes1, axes2)
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
end

# TBD move to GradedArrays?
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
function checkaxes(ax1, ax2)
return checkaxes(Bool, ax1, ax2) ||
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
end

function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
t = (b1, b2)
return Block(Block.(t))[to_block_indices.(t)...]
Expand Down Expand Up @@ -260,9 +244,9 @@ end
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
# find outer block corresponding to fusion trees
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
return Block(b1..., b2...)
b1 = findfirstblock.(flip_dual.(codomain_axes(ft)), leaves(f1))
b2 = findfirstblock.(flip_dual.(domain_axes(ft)), leaves(f2))
return Block(Int.(b1)..., Int.(b2)...)
end

# ============================== GradedArrays interface ==================================
Expand Down
6 changes: 6 additions & 0 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B
return sector_type(type_parameters(type_parameters(BT, 3), 1))
end

function GradedArrays.checkspaces(
::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes
)
return left == right
end

# ============================== FusionTensor interface ==================================

codomain(fta::FusionTensorAxes) = fta[Block(1)]
Expand Down
8 changes: 4 additions & 4 deletions src/fusiontensor/linear_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra, mul!, norm, tr
using BlockArrays: Block, blocks

using BlockSparseArrays: eachblockstoredindex
using GradedArrays: quantum_dimension, sectors
using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors

# allow to contract with different eltype and let BlockSparseArray ensure compatibility
# impose matching type and number of axes at compile time
Expand All @@ -27,9 +27,9 @@ function LinearAlgebra.mul!(
end

# input validation
checkaxes_dual(domain_axes(A), codomain_axes(B))
checkaxes(codomain_axes(C), codomain_axes(A))
checkaxes(domain_axes(C), domain_axes(B))
checkspaces_dual(domain_axes(A), codomain_axes(B))
checkspaces(codomain_axes(C), codomain_axes(A))
checkspaces(domain_axes(C), domain_axes(B))
mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β)
return C
end
Expand Down
9 changes: 4 additions & 5 deletions test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ using FusionTensors:
codomain_axes,
data_matrix,
domain_axes,
checkaxes,
checkaxes_dual,
domain_axis,
codomain_axis,
ndims_codomain,
ndims_domain
using GradedArrays: dual, sectors, sector_multiplicities, space_isequal
using GradedArrays:
checkspaces, checkspaces_dual, dual, sectors, sector_multiplicities, space_isequal

function check_sanity(ft::FusionTensor)
nca = ndims_domain(ft)
Expand All @@ -25,8 +24,8 @@ function check_sanity(ft::FusionTensor)
@assert nda + nca == ndims(ft) "invalid ndims"

@assert length(axes(ft)) == ndims(ft) "ndims does not match axes"
checkaxes(axes(ft)[begin:nda], codomain_axes(ft))
checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft))
checkspaces(axes(ft)[begin:nda], codomain_axes(ft))
checkspaces(axes(ft)[(nda + 1):end], domain_axes(ft))

m = data_matrix(ft)
@assert ndims(m) == 2 "invalid data_matrix ndims"
Expand Down
44 changes: 22 additions & 22 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ using FusionTensors:
data_matrix,
domain_axes,
FusionTensor,
checkaxes,
checkaxes_dual,
codomain_axis,
domain_axis,
ndims_domain,
Expand All @@ -21,6 +19,8 @@ using GradedArrays:
SectorProduct,
TrivialSector,
Z,
checkspaces,
checkspaces_dual,
dual,
flip,
gradedrange,
Expand Down Expand Up @@ -54,11 +54,11 @@ include("setup.jl")

# getters
@test data_matrix(ft1) == m
@test checkaxes(axes(ft1), tuplemortar(((g1,), (g2,))))
@test axes(ft1) == FusionTensorAxes((g1,), (g2,))

# misc
@test checkaxes(codomain_axes(ft1), (g1,))
@test checkaxes(domain_axes(ft1), (g2,))
@test checkspaces(codomain_axes(ft1), (g1,))
@test checkspaces(domain_axes(ft1), (g2,))
@test ndims_codomain(ft1) == 1
@test ndims_domain(ft1) == 1
@test size(data_matrix(ft1)) == (6, 5)
Expand Down Expand Up @@ -86,42 +86,42 @@ include("setup.jl")
@test ft2 !== ft1
@test data_matrix(ft2) == data_matrix(ft1)
@test data_matrix(ft2) !== data_matrix(ft1)
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

ft2 = deepcopy(ft1)
@test ft2 !== ft1
@test data_matrix(ft2) == data_matrix(ft1)
@test data_matrix(ft2) !== data_matrix(ft1)
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

# similar
ft2 = similar(ft1)
@test isnothing(check_sanity(ft2))
@test eltype(ft2) == Float64
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

ft3 = similar(ft1, ComplexF64)
@test isnothing(check_sanity(ft3))
@test eltype(ft3) == ComplexF64
@test checkaxes(codomain_axes(ft3), codomain_axes(ft1))
@test checkaxes(domain_axes(ft3), domain_axes(ft1))
@test checkspaces(codomain_axes(ft3), codomain_axes(ft1))
@test checkspaces(domain_axes(ft3), domain_axes(ft1))

@test_throws AssertionError similar(ft1, Int)

ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,)))
@test isnothing(check_sanity(ft5))
@test eltype(ft5) == ComplexF64
@test checkaxes(codomain_axes(ft5), (g1, g1))
@test checkaxes(domain_axes(ft5), (g2,))
@test checkspaces(codomain_axes(ft5), (g1, g1))
@test checkspaces(domain_axes(ft5), (g2,))

ft5 = similar(ft1, ComplexF32, tuplemortar(((g1, g1), (g2,))))
@test isnothing(check_sanity(ft5))
@test eltype(ft5) == ComplexF64
@test checkaxes(codomain_axes(ft5), (g1, g1))
@test checkaxes(domain_axes(ft5), (g2,))
@test checkspaces(codomain_axes(ft5), (g1, g1))
@test checkspaces(domain_axes(ft5), (g2,))
end

@testset "More than 2 axes" begin
Expand All @@ -135,8 +135,8 @@ end
ft = FusionTensor(m2, (g1, g2), (g3, g4))

@test data_matrix(ft) == m2
@test checkaxes(codomain_axes(ft), (g1, g2))
@test checkaxes(domain_axes(ft), (g3, g4))
@test checkspaces(codomain_axes(ft), (g1, g2))
@test checkspaces(domain_axes(ft), (g3, g4))

@test axes(ft) == FusionTensorAxes(tuplemortar(((g1, g2), (g3, g4))))
@test ndims_codomain(ft) == 2
Expand Down Expand Up @@ -269,9 +269,9 @@ end
@test isnothing(check_sanity(ad))

ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
@test_throws DimensionMismatch ft7 + ft3
@test_throws DimensionMismatch ft7 - ft3
@test_throws DimensionMismatch ft7 * ft3
@test_throws ArgumentError ft7 + ft3
@test_throws ArgumentError ft7 - ft3
@test_throws ArgumentError ft7 * ft3
end

@testset "specific constructors" begin
Expand Down
17 changes: 15 additions & 2 deletions test/test_fusiontensoraxes.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test: @test, @testset
using Test: @test, @test_throws, @testset

using TensorProducts: ⊗
using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks
Expand All @@ -15,7 +15,16 @@ using FusionTensors:
promote_sector_type,
promote_sectors
using GradedArrays:
×, U1, SectorProduct, TrivialSector, SU2, dual, gradedrange, sector_type, space_isequal
×,
U1,
SectorProduct,
TrivialSector,
SU2,
checkspaces,
dual,
gradedrange,
sector_type,
space_isequal

@testset "misc FusionTensors.jl" begin
g1 = gradedrange([U1(0) => 1])
Expand Down Expand Up @@ -83,6 +92,10 @@ end
@test fta != FusionTensorAxes(tuplemortar(((g2, g2, g2b), (g2b,))))

@test fta == FusionTensorAxes((g2, g2), (g2b, g2b))
@test checkspaces(fta, fta)
@test_throws ArgumentError checkspaces(
fta, FusionTensorAxes(tuplemortar(((g2, g2), (g2b, g2))))
)
end

@testset "Empty FusionTensorAxes" begin
Expand Down
6 changes: 3 additions & 3 deletions test/test_permutedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ using Test: @test, @testset, @test_broken, @test_throws

using FusionTensors:
FusionTensor,
FusionTensorAxes,
data_matrix,
checkaxes,
codomain_axis,
domain_axis,
naive_permutedims,
Expand Down Expand Up @@ -40,11 +40,11 @@ include("setup.jl")
ft3 = permutedims(ft1, (4,), (1, 2, 3))
@test ft3 !== ft1
@test ft3 isa FusionTensor{elt,4}
@test checkaxes(axes(ft3), tuplemortar(((dual(g4),), (g1, g2, dual(g3)))))
@test axes(ft3) == FusionTensorAxes((dual(g4),), (g1, g2, dual(g3)))
@test isnothing(check_sanity(ft3))

ft4 = permutedims(ft3, (2, 3), (4, 1))
@test checkaxes(axes(ft1), axes(ft4))
@test axes(ft1) == axes(ft4)
@test space_isequal(codomain_axis(ft1), codomain_axis(ft4))
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
@test ft4 ≈ ft1
Expand Down
Loading