diff --git a/Project.toml b/Project.toml index ce7c92e..8ec8728 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FusionTensors" uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e" authors = ["ITensor developers and contributors"] -version = "0.5.14" +version = "0.5.15" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -29,7 +29,7 @@ LRUCache = "1.6" LinearAlgebra = "1.10" Random = "1.10" Strided = "2.3" -TensorAlgebra = "0.4" +TensorAlgebra = "0.5.1" TensorKitSectors = "0.1, 0.2" TensorProducts = "0.1.7" TypeParameterAccessors = "0.4" diff --git a/src/fusiontensor/fusiontensoraxes.jl b/src/fusiontensor/fusiontensoraxes.jl index efc9799..01303a6 100644 --- a/src/fusiontensor/fusiontensoraxes.jl +++ b/src/fusiontensor/fusiontensoraxes.jl @@ -44,7 +44,7 @@ end # ==================================== Definitions ======================================= # TBD explicit axis type as type parameters? -struct FusionTensorAxes{BT <: BlockedTuple{2}} +struct FusionTensorAxes{BT <: AbstractBlockTuple{2}} <: AbstractBlockTuple{2} outer_axes::BT function FusionTensorAxes{BT}(bt) where {BT} @@ -75,19 +75,19 @@ TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta)) # ================================== Base interface ====================================== for f in [ - :(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(iterate), :(length), + :(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(length), ] @eval Base.$f(fta::FusionTensorAxes) = Base.$f(BlockedTuple(fta)) end -for f in [:(getindex), :(iterate)] - @eval Base.$f(fta::FusionTensorAxes, i) = $f(BlockedTuple(fta), i) -end - +Base.getindex(fta::FusionTensorAxes, i::Int) = BlockedTuple(fta)[i] function Base.getindex(fta::FusionTensorAxes, bp::AbstractBlockPermutation) return FusionTensorAxes(BlockedTuple(fta)[bp]) end +Base.iterate(fta::FusionTensorAxes) = iterate(BlockedTuple(fta)) +Base.iterate(fta::FusionTensorAxes, state::Int) = iterate(BlockedTuple(fta), state) + Base.copy(fta::FusionTensorAxes) = FusionTensorAxes(copy.(BlockedTuple(fta))) Base.deepcopy(fta::FusionTensorAxes) = FusionTensorAxes(deepcopy.(BlockedTuple(fta))) diff --git a/src/fusiontensor/tensor_algebra_interface.jl b/src/fusiontensor/tensor_algebra_interface.jl index c5f3fd3..ade110c 100644 --- a/src/fusiontensor/tensor_algebra_interface.jl +++ b/src/fusiontensor/tensor_algebra_interface.jl @@ -1,53 +1,10 @@ # This file defines TensorAlgebra interface for a FusionTensor -using LinearAlgebra: mul! - using BlockArrays: Block - using GradedArrays: space_isequal -using TensorAlgebra: - TensorAlgebra, - AbstractBlockPermutation, - BlockedTrivialPermutation, - BlockedTuple, - FusionStyle, - Matricize, - blockedperm, - genperm, - matricize, - unmatricize - -const MATRIX_FUNCTIONS = [ - :exp, - :cis, - :log, - :sqrt, - :cbrt, - :cos, - :sin, - :tan, - :csc, - :sec, - :cot, - :cosh, - :sinh, - :tanh, - :csch, - :sech, - :coth, - :acos, - :asin, - :atan, - :acsc, - :asec, - :acot, - :acosh, - :asinh, - :atanh, - :acsch, - :asech, - :acoth, -] +using LinearAlgebra: mul! +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm, + genperm, matricize, unmatricize function TensorAlgebra.output_axes( ::typeof(contract), @@ -75,43 +32,98 @@ struct FusionTensorFusionStyle <: FusionStyle end TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle() +unval(::Val{x}) where {x} = x + function TensorAlgebra.matricize( - ::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2} + ::FusionTensorFusionStyle, ft::AbstractArray, + codomain_length::Val, domain_length::Val ) - blocklengths(biperm) == blocklengths(axes(ft)) || + blocklengths(axes(ft)) == unval.((codomain_length, domain_length)) || throw(ArgumentError("Invalid trivial biperm")) return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),)) end -function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes) - return FusionTensor(data_matrix(m), blocked_axes) +function TensorAlgebra.unmatricize( + ::FusionTensorFusionStyle, + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + return FusionTensor(data_matrix(m), codomain_axes, domain_axes) end function TensorAlgebra.permuteblockeddims( - ft::FusionTensor, biperm::AbstractBlockPermutation + ft::FusionTensor, + codomain_perm::Tuple{Vararg{Int}}, + domain_perm::Tuple{Vararg{Int}}, ) - return permutedims(ft, biperm) + return permutedims(ft, permmortar((codomain_perm, domain_perm))) end function TensorAlgebra.permuteblockeddims!( - a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation + a_dest::FusionTensor, + a_src::FusionTensor, + codomain_perm::Tuple{Vararg{Int}}, + domain_perm::Tuple{Vararg{Int}}, ) - return permutedims!(a, b, biperm) + return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm))) end # TODO define custom broadcast rules -function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β) - a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) +function TensorAlgebra.unmatricizeadd!( + style::FusionTensorFusionStyle, + a_dest::AbstractArray, + a_dest_mat::AbstractMatrix, + codomain_perm::Tuple{Vararg{Int}}, + domain_perm::Tuple{Vararg{Int}}, + α::Number, β::Number, + ) + a12 = unmatricize(a_dest_mat, axes(a_dest), codomain_perm, domain_perm) data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest) return a_dest end +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + for f in MATRIX_FUNCTIONS @eval begin function TensorAlgebra.$f( - a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs... + a::FusionTensor, + codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}}; + kwargs..., ) - a_mat = matricize(a, biperm) + a_mat = matricize(a, codomain_perm, domain_perm) + biperm = permmortar((codomain_perm, domain_perm)) permuted_axes = axes(a)[biperm] checkspaces_dual(codomain(permuted_axes), domain(permuted_axes)) fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...)) diff --git a/test/Project.toml b/test/Project.toml index 3fb032e..af5b399 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,9 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +FusionTensors = {path = ".."} + [compat] Aqua = "0.8.11" BlockArrays = "1.6" @@ -24,6 +27,6 @@ Random = "1.10" SUNRepresentations = "0.3.1" SafeTestsets = "0.1.0" Suppressor = "0.2.8" -TensorAlgebra = "0.4" +TensorAlgebra = "0.5" TensorProducts = "0.1" Test = "1.10.0"