Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.14"
version = "0.5.15"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -29,7 +29,7 @@ LRUCache = "1.6"
LinearAlgebra = "1.10"
Random = "1.10"
Strided = "2.3"
TensorAlgebra = "0.4"
TensorAlgebra = "0.5.1"
TensorKitSectors = "0.1, 0.2"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4"
Expand Down
12 changes: 6 additions & 6 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ end
# ==================================== Definitions =======================================

# TBD explicit axis type as type parameters?
struct FusionTensorAxes{BT <: BlockedTuple{2}}
struct FusionTensorAxes{BT <: AbstractBlockTuple{2}} <: AbstractBlockTuple{2}
outer_axes::BT

function FusionTensorAxes{BT}(bt) where {BT}
Expand Down Expand Up @@ -75,19 +75,19 @@ TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))
# ================================== Base interface ======================================

for f in [
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(iterate), :(length),
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(length),
]
@eval Base.$f(fta::FusionTensorAxes) = Base.$f(BlockedTuple(fta))
end

for f in [:(getindex), :(iterate)]
@eval Base.$f(fta::FusionTensorAxes, i) = $f(BlockedTuple(fta), i)
end

Base.getindex(fta::FusionTensorAxes, i::Int) = BlockedTuple(fta)[i]
function Base.getindex(fta::FusionTensorAxes, bp::AbstractBlockPermutation)
return FusionTensorAxes(BlockedTuple(fta)[bp])
end

Base.iterate(fta::FusionTensorAxes) = iterate(BlockedTuple(fta))
Base.iterate(fta::FusionTensorAxes, state::Int) = iterate(BlockedTuple(fta), state)

Base.copy(fta::FusionTensorAxes) = FusionTensorAxes(copy.(BlockedTuple(fta)))

Base.deepcopy(fta::FusionTensorAxes) = FusionTensorAxes(deepcopy.(BlockedTuple(fta)))
Expand Down
128 changes: 70 additions & 58 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,10 @@
# This file defines TensorAlgebra interface for a FusionTensor

using LinearAlgebra: mul!

using BlockArrays: Block

using GradedArrays: space_isequal
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTrivialPermutation,
BlockedTuple,
FusionStyle,
Matricize,
blockedperm,
genperm,
matricize,
unmatricize

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]
using LinearAlgebra: mul!
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm,
genperm, matricize, unmatricize

function TensorAlgebra.output_axes(
::typeof(contract),
Expand Down Expand Up @@ -75,43 +32,98 @@ struct FusionTensorFusionStyle <: FusionStyle end

TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()

unval(::Val{x}) where {x} = x

function TensorAlgebra.matricize(
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
::FusionTensorFusionStyle, ft::AbstractArray,
codomain_length::Val, domain_length::Val
)
blocklengths(biperm) == blocklengths(axes(ft)) ||
blocklengths(axes(ft)) == unval.((codomain_length, domain_length)) ||
throw(ArgumentError("Invalid trivial biperm"))
return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),))
end

function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
return FusionTensor(data_matrix(m), blocked_axes)
function TensorAlgebra.unmatricize(
::FusionTensorFusionStyle,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
return FusionTensor(data_matrix(m), codomain_axes, domain_axes)
end

function TensorAlgebra.permuteblockeddims(
ft::FusionTensor, biperm::AbstractBlockPermutation
ft::FusionTensor,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
)
return permutedims(ft, biperm)
return permutedims(ft, permmortar((codomain_perm, domain_perm)))
end

function TensorAlgebra.permuteblockeddims!(
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
a_dest::FusionTensor,
a_src::FusionTensor,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
)
return permutedims!(a, b, biperm)
return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm)))
end

# TODO define custom broadcast rules
function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
function TensorAlgebra.unmatricizeadd!(
style::FusionTensorFusionStyle,
a_dest::AbstractArray,
a_dest_mat::AbstractMatrix,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
α::Number, β::Number,
)
a12 = unmatricize(a_dest_mat, axes(a_dest), codomain_perm, domain_perm)
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function TensorAlgebra.$f(
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
a::FusionTensor,
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
kwargs...,
)
a_mat = matricize(a, biperm)
a_mat = matricize(a, codomain_perm, domain_perm)
biperm = permmortar((codomain_perm, domain_perm))
permuted_axes = axes(a)[biperm]
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
Expand Down
5 changes: 4 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
FusionTensors = {path = ".."}

[compat]
Aqua = "0.8.11"
BlockArrays = "1.6"
Expand All @@ -24,6 +27,6 @@ Random = "1.10"
SUNRepresentations = "0.3.1"
SafeTestsets = "0.1.0"
Suppressor = "0.2.8"
TensorAlgebra = "0.4"
TensorAlgebra = "0.5"
TensorProducts = "0.1"
Test = "1.10.0"
Loading