Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
Accessors = "0.1.42"
BlockArrays = "1.6"
BlockSparseArrays = "0.7.4"
GradedArrays = "0.4.13"
GradedArrays = "0.4.14"
HalfIntegers = "1.6"
LRUCache = "1.6"
LinearAlgebra = "1.10"
Expand Down
7 changes: 4 additions & 3 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Accessors: @set
using BlockSparseArrays: @view!, eachstoredblock
using GradedArrays: checkspaces, checkspaces_dual
using TensorAlgebra: BlockedTuple, tuplemortar

set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
Expand All @@ -11,21 +12,21 @@ Base.:*(ft::FusionTensor, x::Number) = x * ft

# tensor contraction is a block data_matrix product.
function Base.:*(left::FusionTensor, right::FusionTensor)
checkaxes_dual(domain_axes(left), codomain_axes(right))
checkspaces_dual(domain_axes(left), codomain_axes(right))
new_data_matrix = data_matrix(left) * data_matrix(right)
return FusionTensor(new_data_matrix, codomain_axes(left), domain_axes(right))
end

# tensor addition is a block data_matrix add.
function Base.:+(left::FusionTensor, right::FusionTensor)
checkaxes(axes(left), axes(right))
axes(left) == axes(right) || throw(ArgumentError("Axes do not match"))
return set_data_matrix(left, data_matrix(left) + data_matrix(right))
end

Base.:-(ft::FusionTensor) = set_data_matrix(ft, -data_matrix(ft))

function Base.:-(left::FusionTensor, right::FusionTensor)
checkaxes(axes(left), axes(right))
axes(left) == axes(right) || throw(ArgumentError("Axes do not match"))
return set_data_matrix(left, data_matrix(left) - data_matrix(right))
end

Expand Down
24 changes: 4 additions & 20 deletions src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using GradedArrays:
SectorProduct,
TrivialSector,
dual,
findfirstblock,
flip,
flip_dual,
gradedrange,
Expand All @@ -27,23 +28,6 @@ using TypeParameterAccessors: type_parameters

# ======================================= Misc ===========================================

# TBD move to GradedArrays? rename findfirst_sector?
function find_sector_block(s::AbstractSector, g::AbstractGradedUnitRange)
return findfirst(==(s), sectors(flip_dual(g)))
end

# TBD move to GradedArrays?
function checkaxes(::Type{Bool}, axes1, axes2)
return length(axes1) == length(axes2) && all(space_isequal.(axes1, axes2))
end

# TBD move to GradedArrays?
checkaxes_dual(axes1, axes2) = checkaxes(axes1, dual.(axes2))
function checkaxes(ax1, ax2)
return checkaxes(Bool, ax1, ax2) ||
throw(DimensionMismatch(lazy"$ax1 does not match $ax2"))
end

function to_blockindexrange(b1::BlockIndexRange{1}, b2::BlockIndexRange{1})
t = (b1, b2)
return Block(Block.(t))[to_block_indices.(t)...]
Expand Down Expand Up @@ -260,9 +244,9 @@ end
function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
# find outer block corresponding to fusion trees
@assert typeof((f1, f2)) === keytype(trees_block_mapping(ft))
b1 = find_sector_block.(leaves(f1), codomain_axes(ft))
b2 = find_sector_block.(leaves(f2), domain_axes(ft))
return Block(b1..., b2...)
b1 = findfirstblock.(flip_dual.(codomain_axes(ft)), leaves(f1))
b2 = findfirstblock.(flip_dual.(domain_axes(ft)), leaves(f2))
return Block(Int.(b1)..., Int.(b2)...)
end

# ============================== GradedArrays interface ==================================
Expand Down
8 changes: 4 additions & 4 deletions src/fusiontensor/linear_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearAlgebra: LinearAlgebra, mul!, norm, tr
using BlockArrays: Block, blocks

using BlockSparseArrays: eachblockstoredindex
using GradedArrays: quantum_dimension, sectors
using GradedArrays: checkspaces, checkspaces_dual, quantum_dimension, sectors

# allow to contract with different eltype and let BlockSparseArray ensure compatibility
# impose matching type and number of axes at compile time
Expand All @@ -27,9 +27,9 @@ function LinearAlgebra.mul!(
end

# input validation
checkaxes_dual(domain_axes(A), codomain_axes(B))
checkaxes(codomain_axes(C), codomain_axes(A))
checkaxes(domain_axes(C), domain_axes(B))
checkspaces_dual(domain_axes(A), codomain_axes(B))
checkspaces(codomain_axes(C), codomain_axes(A))
checkspaces(domain_axes(C), domain_axes(B))
mul!(data_matrix(C), data_matrix(A), data_matrix(B), α, β)
return C
end
Expand Down
9 changes: 4 additions & 5 deletions test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ using FusionTensors:
codomain_axes,
data_matrix,
domain_axes,
checkaxes,
checkaxes_dual,
domain_axis,
codomain_axis,
ndims_codomain,
ndims_domain
using GradedArrays: dual, sectors, sector_multiplicities, space_isequal
using GradedArrays:
checkspaces, checkspaces_dual, dual, sectors, sector_multiplicities, space_isequal

function check_sanity(ft::FusionTensor)
nca = ndims_domain(ft)
Expand All @@ -25,8 +24,8 @@ function check_sanity(ft::FusionTensor)
@assert nda + nca == ndims(ft) "invalid ndims"

@assert length(axes(ft)) == ndims(ft) "ndims does not match axes"
checkaxes(axes(ft)[begin:nda], codomain_axes(ft))
checkaxes(axes(ft)[(nda + 1):end], domain_axes(ft))
checkspaces(axes(ft)[begin:nda], codomain_axes(ft))
checkspaces(axes(ft)[(nda + 1):end], domain_axes(ft))

m = data_matrix(ft)
@assert ndims(m) == 2 "invalid data_matrix ndims"
Expand Down
44 changes: 22 additions & 22 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ using FusionTensors:
data_matrix,
domain_axes,
FusionTensor,
checkaxes,
checkaxes_dual,
codomain_axis,
domain_axis,
ndims_domain,
Expand All @@ -21,6 +19,8 @@ using GradedArrays:
SectorProduct,
TrivialSector,
Z,
checkspaces,
checkspaces_dual,
dual,
flip,
gradedrange,
Expand Down Expand Up @@ -54,11 +54,11 @@ include("setup.jl")

# getters
@test data_matrix(ft1) == m
@test checkaxes(axes(ft1), tuplemortar(((g1,), (g2,))))
@test axes(ft1) == FusionTensorAxes((g1,), (g2,))

# misc
@test checkaxes(codomain_axes(ft1), (g1,))
@test checkaxes(domain_axes(ft1), (g2,))
@test checkspaces(codomain_axes(ft1), (g1,))
@test checkspaces(domain_axes(ft1), (g2,))
@test ndims_codomain(ft1) == 1
@test ndims_domain(ft1) == 1
@test size(data_matrix(ft1)) == (6, 5)
Expand Down Expand Up @@ -86,42 +86,42 @@ include("setup.jl")
@test ft2 !== ft1
@test data_matrix(ft2) == data_matrix(ft1)
@test data_matrix(ft2) !== data_matrix(ft1)
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

ft2 = deepcopy(ft1)
@test ft2 !== ft1
@test data_matrix(ft2) == data_matrix(ft1)
@test data_matrix(ft2) !== data_matrix(ft1)
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

# similar
ft2 = similar(ft1)
@test isnothing(check_sanity(ft2))
@test eltype(ft2) == Float64
@test checkaxes(codomain_axes(ft2), codomain_axes(ft1))
@test checkaxes(domain_axes(ft2), domain_axes(ft1))
@test checkspaces(codomain_axes(ft2), codomain_axes(ft1))
@test checkspaces(domain_axes(ft2), domain_axes(ft1))

ft3 = similar(ft1, ComplexF64)
@test isnothing(check_sanity(ft3))
@test eltype(ft3) == ComplexF64
@test checkaxes(codomain_axes(ft3), codomain_axes(ft1))
@test checkaxes(domain_axes(ft3), domain_axes(ft1))
@test checkspaces(codomain_axes(ft3), codomain_axes(ft1))
@test checkspaces(domain_axes(ft3), domain_axes(ft1))

@test_throws AssertionError similar(ft1, Int)

ft5 = similar(ft1, ComplexF32, ((g1, g1), (g2,)))
@test isnothing(check_sanity(ft5))
@test eltype(ft5) == ComplexF64
@test checkaxes(codomain_axes(ft5), (g1, g1))
@test checkaxes(domain_axes(ft5), (g2,))
@test checkspaces(codomain_axes(ft5), (g1, g1))
@test checkspaces(domain_axes(ft5), (g2,))

ft5 = similar(ft1, ComplexF32, tuplemortar(((g1, g1), (g2,))))
@test isnothing(check_sanity(ft5))
@test eltype(ft5) == ComplexF64
@test checkaxes(codomain_axes(ft5), (g1, g1))
@test checkaxes(domain_axes(ft5), (g2,))
@test checkspaces(codomain_axes(ft5), (g1, g1))
@test checkspaces(domain_axes(ft5), (g2,))
end

@testset "More than 2 axes" begin
Expand All @@ -135,8 +135,8 @@ end
ft = FusionTensor(m2, (g1, g2), (g3, g4))

@test data_matrix(ft) == m2
@test checkaxes(codomain_axes(ft), (g1, g2))
@test checkaxes(domain_axes(ft), (g3, g4))
@test checkspaces(codomain_axes(ft), (g1, g2))
@test checkspaces(domain_axes(ft), (g3, g4))

@test axes(ft) == FusionTensorAxes(tuplemortar(((g1, g2), (g3, g4))))
@test ndims_codomain(ft) == 2
Expand Down Expand Up @@ -269,9 +269,9 @@ end
@test isnothing(check_sanity(ad))

ft7 = FusionTensor{Float64}(undef, (g1,), (g2, g3, g4))
@test_throws DimensionMismatch ft7 + ft3
@test_throws DimensionMismatch ft7 - ft3
@test_throws DimensionMismatch ft7 * ft3
@test_throws ArgumentError ft7 + ft3
@test_throws ArgumentError ft7 - ft3
@test_throws ArgumentError ft7 * ft3
end

@testset "specific constructors" begin
Expand Down
6 changes: 3 additions & 3 deletions test/test_permutedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ using Test: @test, @testset, @test_broken, @test_throws

using FusionTensors:
FusionTensor,
FusionTensorAxes,
data_matrix,
checkaxes,
codomain_axis,
domain_axis,
naive_permutedims,
Expand Down Expand Up @@ -40,11 +40,11 @@ include("setup.jl")
ft3 = permutedims(ft1, (4,), (1, 2, 3))
@test ft3 !== ft1
@test ft3 isa FusionTensor{elt,4}
@test checkaxes(axes(ft3), tuplemortar(((dual(g4),), (g1, g2, dual(g3)))))
@test axes(ft3) == FusionTensorAxes((dual(g4),), (g1, g2, dual(g3)))
@test isnothing(check_sanity(ft3))

ft4 = permutedims(ft3, (2, 3), (4, 1))
@test checkaxes(axes(ft1), axes(ft4))
@test axes(ft1) == axes(ft4)
@test space_isequal(codomain_axis(ft1), codomain_axis(ft4))
@test space_isequal(domain_axis(ft1), domain_axis(ft4))
@test ft4 ≈ ft1
Expand Down
Loading