Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.1"
version = "0.4.3"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -11,6 +11,7 @@ GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Expand All @@ -25,6 +26,7 @@ GradedArrays = "0.4.13"
HalfIntegers = "1.6"
LRUCache = "1.6"
LinearAlgebra = "1.10"
Random = "1.11.0"
Strided = "2.3"
TensorAlgebra = "0.3.8"
TensorProducts = "0.1.7"
Expand Down
4 changes: 2 additions & 2 deletions src/fusiontensor/base_interface.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This files defines Base functions for FusionTensor

using Accessors: @set

using BlockSparseArrays: @view!
using Random: randn!
using BlockSparseArrays: @view!, eachstoredblock
using TensorAlgebra: BlockedTuple, tuplemortar

set_data_matrix(ft::FusionTensor, data_matrix) = @set ft.data_matrix = data_matrix
Expand Down
25 changes: 25 additions & 0 deletions src/fusiontensor/fusiontensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using GradedArrays:
sectormergesort,
sectors,
space_isequal
using LinearAlgebra: UniformScaling
using TensorAlgebra: BlockedTuple, trivial_axis, tuplemortar
using TensorProducts: tensor_product
using TypeParameterAccessors: type_parameters
Expand Down Expand Up @@ -209,6 +210,30 @@ function FusionTensor(
return FusionTensor(x, tuplemortar((codomain_legs, domain_legs)))
end

# specific constructors
Base.zeros(::Type{T}, fta::FusionTensorAxes) where {T} = FusionTensor(T, fta)
Base.zeros(fta::FusionTensorAxes) = zeros(Float64, fta)

function Base.randn(::Type{T}, fta::FusionTensorAxes) where {T}
ft = FusionTensor(T, fta)
for m in eachstoredblock(data_matrix(ft))
m = randn!(m)
end
return ft
end
Base.randn(fta::FusionTensorAxes) = randn(Float64, fta)

function FusionTensor(
::UniformScaling, codomain_legs::Tuple{Vararg{AbstractGradedUnitRange}}
)
fta = FusionTensorAxes(codomain_legs, dual.(codomain_legs))
ft = FusionTensor(Float64, fta)
for m in eachstoredblock(data_matrix(ft))
m .= LinearAlgebra.I(size(m, 1))
end
return ft
end

# ================================ BlockArrays interface =================================

function BlockArrays.findblock(ft::FusionTensor, f1::SectorFusionTree, f2::SectorFusionTree)
Expand Down
33 changes: 32 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test: @test, @test_throws, @testset

using BlockArrays: Block
using BlockSparseArrays: BlockSparseArray
using BlockSparseArrays: BlockSparseArray, eachblockstoredindex
using FusionTensors:
FusionTensor,
FusionTensorAxes,
Expand All @@ -28,13 +28,20 @@ using GradedArrays:
space_isequal
using TensorAlgebra: tuplemortar
using TensorProducts: tensor_product
using LinearAlgebra: LinearAlgebra

include("setup.jl")

@testset "Fusion matrix" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = dual(gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]))

fta = FusionTensorAxes((g1,), (g2,))
ft0 = FusionTensor(Float64, fta)
@test ft0 isa FusionTensor
@test space_isequal(codomain_axis(ft0), g1)
@test space_isequal(domain_axis(ft0), g2)

# check dual convention when initializing data_matrix
ft0 = FusionTensor(Float64, (g1,), (g2,))
@test ft0 isa FusionTensor
Expand Down Expand Up @@ -266,6 +273,30 @@ end
@test_throws DimensionMismatch ft7 * ft3
end

@testset "specific constructors" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
g4 = gradedrange([U1(-1) => 1, U1(0) => 1, U1(1) => 1])

fta = FusionTensorAxes((g1,), (g2, g3))
@test zeros(fta) isa FusionTensor{Float64,3}
@test zeros(ComplexF64, fta) isa FusionTensor{ComplexF64,3}
ft1 = randn(ComplexF64, fta)
@test ft1 isa FusionTensor{ComplexF64,3}
@test all(data_matrix(ft1)[Block(1, 5)] .!= 0)
@test randn(fta) isa FusionTensor{Float64,3}

ft2 = FusionTensor(LinearAlgebra.I, (g1, g2))
@test ft2 isa FusionTensor{Float64,4}
@test axes(ft2) == FusionTensorAxes((g1, g2), dual.((g1, g2)))
@test collect(eachblockstoredindex(data_matrix(ft2))) == map(i -> Block(i, i), 1:6)
for i in 1:6
m = data_matrix(ft2)[Block(i, i)]
@test m == LinearAlgebra.I(size(m, 1))
end
end

@testset "missing SectorProduct" begin
g1 = gradedrange([SectorProduct(U1(1)) => 1])
g2 = gradedrange([SectorProduct(U1(1), SU2(1//2)) => 1])
Expand Down
Loading