Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.14"
version = "0.5.15"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -29,7 +29,7 @@ LRUCache = "1.6"
LinearAlgebra = "1.10"
Random = "1.10"
Strided = "2.3"
TensorAlgebra = "0.4"
TensorAlgebra = "0.5.1"
TensorKitSectors = "0.1, 0.2"
TensorProducts = "0.1.7"
TypeParameterAccessors = "0.4"
Expand Down
12 changes: 6 additions & 6 deletions src/fusiontensor/fusiontensoraxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ end
# ==================================== Definitions =======================================

# TBD explicit axis type as type parameters?
struct FusionTensorAxes{BT <: BlockedTuple{2}}
struct FusionTensorAxes{BT <: AbstractBlockTuple{2}} <: AbstractBlockTuple{2}
outer_axes::BT

function FusionTensorAxes{BT}(bt) where {BT}
Expand Down Expand Up @@ -75,19 +75,19 @@ TensorAlgebra.length_domain(fta::FusionTensorAxes) = length(domain(fta))
# ================================== Base interface ======================================

for f in [
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(iterate), :(length),
:(broadcastable), :(Tuple), :(axes), :(firstindex), :(lastindex), :(length),
]
@eval Base.$f(fta::FusionTensorAxes) = Base.$f(BlockedTuple(fta))
end

for f in [:(getindex), :(iterate)]
@eval Base.$f(fta::FusionTensorAxes, i) = $f(BlockedTuple(fta), i)
end

Base.getindex(fta::FusionTensorAxes, i::Int) = BlockedTuple(fta)[i]
function Base.getindex(fta::FusionTensorAxes, bp::AbstractBlockPermutation)
return FusionTensorAxes(BlockedTuple(fta)[bp])
end

Base.iterate(fta::FusionTensorAxes) = iterate(BlockedTuple(fta))
Base.iterate(fta::FusionTensorAxes, state::Int) = iterate(BlockedTuple(fta), state)

Base.copy(fta::FusionTensorAxes) = FusionTensorAxes(copy.(BlockedTuple(fta)))

Base.deepcopy(fta::FusionTensorAxes) = FusionTensorAxes(deepcopy.(BlockedTuple(fta)))
Expand Down
128 changes: 70 additions & 58 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,10 @@
# This file defines TensorAlgebra interface for a FusionTensor

using LinearAlgebra: mul!

using BlockArrays: Block

using GradedArrays: space_isequal
using TensorAlgebra:
TensorAlgebra,
AbstractBlockPermutation,
BlockedTrivialPermutation,
BlockedTuple,
FusionStyle,
Matricize,
blockedperm,
genperm,
matricize,
unmatricize

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]
using LinearAlgebra: mul!
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, blockedperm,
genperm, matricize, unmatricize

function TensorAlgebra.output_axes(
::typeof(contract),
Expand Down Expand Up @@ -75,43 +32,98 @@ struct FusionTensorFusionStyle <: FusionStyle end

TensorAlgebra.FusionStyle(::Type{<:FusionTensor}) = FusionTensorFusionStyle()

unval(::Val{x}) where {x} = x

function TensorAlgebra.matricize(
::FusionTensorFusionStyle, ft::AbstractArray, biperm::BlockedTrivialPermutation{2}
::FusionTensorFusionStyle, ft::AbstractArray,
codomain_length::Val, domain_length::Val
)
blocklengths(biperm) == blocklengths(axes(ft)) ||
blocklengths(axes(ft)) == unval.((codomain_length, domain_length)) ||
throw(ArgumentError("Invalid trivial biperm"))
return FusionTensor(data_matrix(ft), (codomain_axis(ft),), (domain_axis(ft),))
end

function TensorAlgebra.unmatricize(::FusionTensorFusionStyle, m, blocked_axes)
return FusionTensor(data_matrix(m), blocked_axes)
function TensorAlgebra.unmatricize(
::FusionTensorFusionStyle,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
return FusionTensor(data_matrix(m), codomain_axes, domain_axes)
end

function TensorAlgebra.permuteblockeddims(
ft::FusionTensor, biperm::AbstractBlockPermutation
ft::FusionTensor,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
)
return permutedims(ft, biperm)
return permutedims(ft, permmortar((codomain_perm, domain_perm)))
end

function TensorAlgebra.permuteblockeddims!(
a::FusionTensor, b::FusionTensor, biperm::AbstractBlockPermutation
a_dest::FusionTensor,
a_src::FusionTensor,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
)
return permutedims!(a, b, biperm)
return permutedims!(a_dest, a_src, permmortar((codomain_perm, domain_perm)))
end

# TODO define custom broadcast rules
function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbiperm, α, β)
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
function TensorAlgebra.unmatricizeadd!(
style::FusionTensorFusionStyle,
a_dest::AbstractArray,
a_dest_mat::AbstractMatrix,
codomain_perm::Tuple{Vararg{Int}},
domain_perm::Tuple{Vararg{Int}},
α::Number, β::Number,
)
a12 = unmatricize(a_dest_mat, axes(a_dest), codomain_perm, domain_perm)
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function TensorAlgebra.$f(
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
a::FusionTensor,
codomain_perm::Tuple{Vararg{Int}}, domain_perm::Tuple{Vararg{Int}};
kwargs...,
)
a_mat = matricize(a, biperm)
a_mat = matricize(a, codomain_perm, domain_perm)
biperm = permmortar((codomain_perm, domain_perm))
permuted_axes = axes(a)[biperm]
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
Expand Down
6 changes: 5 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
FusionTensors = {path = "/Users/mfishman/.julia/dev/FusionTensors"}
TensorAlgebra = {path = "/Users/mfishman/.julia/dev/TensorAlgebra"}

[compat]
Aqua = "0.8.11"
BlockArrays = "1.6"
Expand All @@ -24,6 +28,6 @@ Random = "1.10"
SUNRepresentations = "0.3.1"
SafeTestsets = "0.1.0"
Suppressor = "0.2.8"
TensorAlgebra = "0.4"
TensorAlgebra = "0.5"
TensorProducts = "0.1"
Test = "1.10.0"
Loading