diff --git a/Project.toml b/Project.toml index 2d9e5ab..d5f94a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FusionTensors" uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e" authors = ["ITensor developers and contributors"] -version = "0.5.0" +version = "0.5.1" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" Accessors = "0.1.42" BlockArrays = "1.6" BlockSparseArrays = "0.7.4" -GradedArrays = "0.4.13" +GradedArrays = "0.4.14" HalfIntegers = "1.6" LRUCache = "1.6" LinearAlgebra = "1.10" diff --git a/src/fusiontensor/base_interface.jl b/src/fusiontensor/base_interface.jl index 5befd60..4def6f8 100644 --- a/src/fusiontensor/base_interface.jl +++ b/src/fusiontensor/base_interface.jl @@ -2,6 +2,7 @@ using Accessors: @set using BlockSparseArrays: @view!, eachstoredblock +using GradedArrays: checkspaces, checkspaces_dual using TensorAlgebra: BlockedTuple, tuplemortar set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix @@ -11,21 +12,21 @@ Base.:*(ft::FusionTensor, x::Number) = x * ft # tensor contraction is a block data_matrix product. function Base.:*(left::FusionTensor, right::FusionTensor) - checkaxes_dual(domain_axes(left), codomain_axes(right)) + checkspaces_dual(domain_axes(left), codomain_axes(right)) new_data_matrix = data_matrix(left) * data_matrix(right) return FusionTensor(new_data_matrix, codomain_axes(left), domain_axes(right)) end # tensor addition is a block data_matrix add. function Base.:+(left::FusionTensor, right::FusionTensor) - checkaxes(axes(left), axes(right)) + checkspaces(axes(left), axes(right)) return set_data_matrix(left, data_matrix(left) + data_matrix(right)) end Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft)) function Base.:-(left::FusionTensor, right::FusionTensor) - checkaxes(axes(left), axes(right)) + checkspaces(axes(left), axes(right)) return set_data_matrix(left, data_matrix(left) - data_matrix(right)) end diff --git a/src/fusiontensor/fusiontensor.jl b/src/fusiontensor/fusiontensor.jl index 490644b..6b6b76b 100644 --- a/src/fusiontensor/fusiontensor.jl +++ b/src/fusiontensor/fusiontensor.jl @@ -9,6 +9,7 @@ using GradedArrays: SectorProduct, TrivialSector, dual, + findfirstblock, flip, flip_dual, gradedrange, @@ -27,23 +28,6 @@ using TypeParameterAccessors: type_parameters # ======================================= Misc =========================================== -# TBD move to GradedArrays? rename findfirst_sector? -function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange) - return findfirst(==(s), sectors(flip_dual(g))) -end - -# TBD move to GradedArrays? -function checkaxes(::Type{Bool}, axes1, axes2) - return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2)) -end - -# TBD move to GradedArrays? -checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2)) -function checkaxes(ax1, ax2) - return checkaxes(Bool, ax1, ax2) || - throw(DimensionMismatch(lazy"$ax1 does not match $ax2")) -end - function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1}) t = (b1, b2) return Block(Block.(t))[to_block_indices.(t)...] @@ -260,9 +244,9 @@ end function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree) # find outer block corresponding to fusion trees @assert typeof((f1, f2)) === keytype(trees_block_mapping(ft)) - b1 = find_sector_block.(leaves(f1), codomain_axes(ft)) - b2 = find_sector_block.(leaves(f2), domain_axes(ft)) - return Block(b1..., b2...) + b1 = findfirstblock.(flip_dual.(codomain_axes(ft)), leaves(f1)) + b2 = findfirstblock.(flip_dual.(domain_axes(ft)), leaves(f2)) + return Block(Int.(b1)..., Int.(b2)...) end # ============================== GradedArrays interface ================================== diff --git a/src/fusiontensor/fusiontensoraxes.jl b/src/fusiontensor/fusiontensoraxes.jl index bf3a333..d86f8c8 100644 --- a/src/fusiontensor/fusiontensoraxes.jl +++ b/src/fusiontensor/fusiontensoraxes.jl @@ -110,6 +110,12 @@ function GradedArrays.sector_type(::Type{FTA}) where {BT,FTA<:FusionTensorAxes{B return sector_type(type_parameters(type_parameters(BT, 3), 1)) end +function GradedArrays.checkspaces( + ::Type{Bool}, left::FusionTensorAxes, right::FusionTensorAxes +) + return left == right +end + # ============================== FusionTensor interface ================================== codomain(fta::FusionTensorAxes) = fta[Block(1)] diff --git a/src/fusiontensor/linear_algebra_interface.jl b/src/fusiontensor/linear_algebra_interface.jl index 4a631e5..68cbb70 100644 --- a/src/fusiontensor/linear_algebra_interface.jl +++ b/src/fusiontensor/linear_algebra_interface.jl @@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra, mul!, norm, tr using BlockArrays: Block, blocks using BlockSparseArrays: eachblockstoredindex -using GradedArrays: quantum_dimension, sectors +using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors # allow to contract with different eltype and let BlockSparseArray ensure compatibility # impose matching type and number of axes at compile time @@ -27,9 +27,9 @@ function LinearAlgebra.mul!( end # input validation - checkaxes_dual(domain_axes(A), codomain_axes(B)) - checkaxes(codomain_axes(C), codomain_axes(A)) - checkaxes(domain_axes(C), domain_axes(B)) + checkspaces_dual(domain_axes(A), codomain_axes(B)) + checkspaces(codomain_axes(C), codomain_axes(A)) + checkspaces(domain_axes(C), domain_axes(B)) mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β) return C end diff --git a/test/setup.jl b/test/setup.jl index 0607a81..b36de0e 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -6,13 +6,12 @@ using FusionTensors: codomain_axes, data_matrix, domain_axes, - checkaxes, - checkaxes_dual, domain_axis, codomain_axis, ndims_codomain, ndims_domain -using GradedArrays: dual, sectors, sector_multiplicities, space_isequal +using GradedArrays: + checkspaces, checkspaces_dual, dual, sectors, sector_multiplicities, space_isequal function check_sanity(ft::FusionTensor) nca = ndims_domain(ft) @@ -25,8 +24,8 @@ function check_sanity(ft::FusionTensor) @assert nda + nca == ndims(ft) "invalid ndims" @assert length(axes(ft)) == ndims(ft) "ndims does not match axes" - checkaxes(axes(ft)[begin:nda], codomain_axes(ft)) - checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft)) + checkspaces(axes(ft)[begin:nda], codomain_axes(ft)) + checkspaces(axes(ft)[(nda + 1):end], domain_axes(ft)) m = data_matrix(ft) @assert ndims(m) == 2 "invalid data_matrix ndims" diff --git a/test/test_basics.jl b/test/test_basics.jl index 7e076b8..295cfc0 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -9,8 +9,6 @@ using FusionTensors: data_matrix, domain_axes, FusionTensor, - checkaxes, - checkaxes_dual, codomain_axis, domain_axis, ndims_domain, @@ -21,6 +19,8 @@ using GradedArrays: SectorProduct, TrivialSector, Z, + checkspaces, + checkspaces_dual, dual, flip, gradedrange, @@ -54,11 +54,11 @@ include("setup.jl") # getters @test data_matrix(ft1) == m - @test checkaxes(axes(ft1), tuplemortar(((g1,), (g2,)))) + @test axes(ft1) == FusionTensorAxes((g1,), (g2,)) # misc - @test checkaxes(codomain_axes(ft1), (g1,)) - @test checkaxes(domain_axes(ft1), (g2,)) + @test checkspaces(codomain_axes(ft1), (g1,)) + @test checkspaces(domain_axes(ft1), (g2,)) @test ndims_codomain(ft1) == 1 @test ndims_domain(ft1) == 1 @test size(data_matrix(ft1)) == (6, 5) @@ -86,42 +86,42 @@ include("setup.jl") @test ft2 !== ft1 @test data_matrix(ft2) == data_matrix(ft1) @test data_matrix(ft2) !== data_matrix(ft1) - @test checkaxes(codomain_axes(ft2), codomain_axes(ft1)) - @test checkaxes(domain_axes(ft2), domain_axes(ft1)) + @test checkspaces(codomain_axes(ft2), codomain_axes(ft1)) + @test checkspaces(domain_axes(ft2), domain_axes(ft1)) ft2 = deepcopy(ft1) @test ft2 !== ft1 @test data_matrix(ft2) == data_matrix(ft1) @test data_matrix(ft2) !== data_matrix(ft1) - @test checkaxes(codomain_axes(ft2), codomain_axes(ft1)) - @test checkaxes(domain_axes(ft2), domain_axes(ft1)) + @test checkspaces(codomain_axes(ft2), codomain_axes(ft1)) + @test checkspaces(domain_axes(ft2), domain_axes(ft1)) # similar ft2 = similar(ft1) @test isnothing(check_sanity(ft2)) @test eltype(ft2) == Float64 - @test checkaxes(codomain_axes(ft2), codomain_axes(ft1)) - @test checkaxes(domain_axes(ft2), domain_axes(ft1)) + @test checkspaces(codomain_axes(ft2), codomain_axes(ft1)) + @test checkspaces(domain_axes(ft2), domain_axes(ft1)) ft3 = similar(ft1, ComplexF64) @test isnothing(check_sanity(ft3)) @test eltype(ft3) == ComplexF64 - @test checkaxes(codomain_axes(ft3), codomain_axes(ft1)) - @test checkaxes(domain_axes(ft3), domain_axes(ft1)) + @test checkspaces(codomain_axes(ft3), codomain_axes(ft1)) + @test checkspaces(domain_axes(ft3), domain_axes(ft1)) @test_throws AssertionError similar(ft1, Int) ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,))) @test isnothing(check_sanity(ft5)) @test eltype(ft5) == ComplexF64 - @test checkaxes(codomain_axes(ft5), (g1, g1)) - @test checkaxes(domain_axes(ft5), (g2,)) + @test checkspaces(codomain_axes(ft5), (g1, g1)) + @test checkspaces(domain_axes(ft5), (g2,)) ft5 = similar(ft1, ComplexF32, tuplemortar(((g1, g1), (g2,)))) @test isnothing(check_sanity(ft5)) @test eltype(ft5) == ComplexF64 - @test checkaxes(codomain_axes(ft5), (g1, g1)) - @test checkaxes(domain_axes(ft5), (g2,)) + @test checkspaces(codomain_axes(ft5), (g1, g1)) + @test checkspaces(domain_axes(ft5), (g2,)) end @testset "More than 2 axes" begin @@ -135,8 +135,8 @@ end ft = FusionTensor(m2, (g1, g2), (g3, g4)) @test data_matrix(ft) == m2 - @test checkaxes(codomain_axes(ft), (g1, g2)) - @test checkaxes(domain_axes(ft), (g3, g4)) + @test checkspaces(codomain_axes(ft), (g1, g2)) + @test checkspaces(domain_axes(ft), (g3, g4)) @test axes(ft) == FusionTensorAxes(tuplemortar(((g1, g2), (g3, g4)))) @test ndims_codomain(ft) == 2 @@ -269,9 +269,9 @@ end @test isnothing(check_sanity(ad)) ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4)) - @test_throws DimensionMismatch ft7 + ft3 - @test_throws DimensionMismatch ft7 - ft3 - @test_throws DimensionMismatch ft7 * ft3 + @test_throws ArgumentError ft7 + ft3 + @test_throws ArgumentError ft7 - ft3 + @test_throws ArgumentError ft7 * ft3 end @testset "specific constructors" begin diff --git a/test/test_fusiontensoraxes.jl b/test/test_fusiontensoraxes.jl index 4a1cbeb..6ac3f4d 100644 --- a/test/test_fusiontensoraxes.jl +++ b/test/test_fusiontensoraxes.jl @@ -1,4 +1,4 @@ -using Test: @test, @testset +using Test: @test, @test_throws, @testset using TensorProducts: ⊗ using BlockArrays: Block, blockedrange, blocklength, blocklengths, blocks @@ -15,7 +15,16 @@ using FusionTensors: promote_sector_type, promote_sectors using GradedArrays: - ×, U1, SectorProduct, TrivialSector, SU2, dual, gradedrange, sector_type, space_isequal + ×, + U1, + SectorProduct, + TrivialSector, + SU2, + checkspaces, + dual, + gradedrange, + sector_type, + space_isequal @testset "misc FusionTensors.jl" begin g1 = gradedrange([U1(0) => 1]) @@ -83,6 +92,10 @@ end @test fta != FusionTensorAxes(tuplemortar(((g2, g2, g2b), (g2b,)))) @test fta == FusionTensorAxes((g2, g2), (g2b, g2b)) + @test checkspaces(fta, fta) + @test_throws ArgumentError checkspaces( + fta, FusionTensorAxes(tuplemortar(((g2, g2), (g2b, g2)))) + ) end @testset "Empty FusionTensorAxes" begin diff --git a/test/test_permutedims.jl b/test/test_permutedims.jl index fd293a5..2444026 100644 --- a/test/test_permutedims.jl +++ b/test/test_permutedims.jl @@ -2,8 +2,8 @@ using Test: @test, @testset, @test_broken, @test_throws using FusionTensors: FusionTensor, + FusionTensorAxes, data_matrix, - checkaxes, codomain_axis, domain_axis, naive_permutedims, @@ -40,11 +40,11 @@ include("setup.jl") ft3 = permutedims(ft1, (4,), (1, 2, 3)) @test ft3 !== ft1 @test ft3 isa FusionTensor{elt,4} - @test checkaxes(axes(ft3), tuplemortar(((dual(g4),), (g1, g2, dual(g3))))) + @test axes(ft3) == FusionTensorAxes((dual(g4),), (g1, g2, dual(g3))) @test isnothing(check_sanity(ft3)) ft4 = permutedims(ft3, (2, 3), (4, 1)) - @test checkaxes(axes(ft1), axes(ft4)) + @test axes(ft1) == axes(ft4) @test space_isequal(codomain_axis(ft1), codomain_axis(ft4)) @test space_isequal(domain_axis(ft1), domain_axis(ft4)) @test ft4 ≈ ft1