Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.9"
version = "0.5.11"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
Accessors = "0.1.42"
BlockArrays = "1.7"
BlockSparseArrays = "0.10"
GradedArrays = "0.4.14"
GradedArrays = "0.4"
HalfIntegers = "1.6"
LRUCache = "1.6"
LinearAlgebra = "1.10"
Expand Down
47 changes: 47 additions & 0 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,41 @@ using TensorAlgebra:
Matricize,
blockedperm,
genperm,
matricize,
unmatricize

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

function TensorAlgebra.output_axes(
::typeof(contract),
biperm_dest::AbstractBlockPermutation{2},
Expand Down Expand Up @@ -70,3 +103,17 @@ function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbipe
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end

for f in MATRIX_FUNCTIONS
@eval begin
function TensorAlgebra.$f(
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
)
a_mat = matricize(a, biperm)
permuted_axes = axes(a)[biperm]
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
return unmatricize(fa_mat, permuted_axes)
end
end
end
29 changes: 25 additions & 4 deletions test/test_contraction.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using LinearAlgebra: mul!
using Test: @test, @testset, @test_broken
using Test: @test, @testset, @test_broken, @test_throws

using BlockSparseArrays: BlockSparseArray
using FusionTensors: FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
using FusionTensors:
FusionTensor, FusionTensorAxes, domain_axes, codomain_axes, to_fusiontensor
using GradedArrays: SU2, U1, dual, gradedrange
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
using TensorAlgebra:
TensorAlgebra, contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!

include("setup.jl")

Expand Down Expand Up @@ -52,6 +54,25 @@ end
@test ft1 ≈ ft2
end

@testset "Matrix functions" begin
sds22 = [
0.25 0.0 0.0 0.0
0.0 -0.25 0.5 0.0
0.0 0.5 -0.25 0.0
0.0 0.0 0.0 0.25
]
t = reshape(sds22, (2, 2, 2, 2))
g2 = gradedrange([SU2(1//2) => 1])
ft = to_fusiontensor(t, (g2, g2), (dual(g2), dual(g2)))
for f in setdiff(FusionTensors.MATRIX_FUNCTIONS, [:acoth, :cbrt])
t2 = reshape((@eval Base.$f)(sds22), (2, 2, 2, 2))
ft2 = to_fusiontensor(t2, (g2, g2), (dual(g2), dual(g2)))
@test (@eval TensorAlgebra.$f)(ft, (1, 2, 3, 4), (1, 2), (3, 4)) ≈ ft2
end
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 2, 3), (4,))
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 3), (2, 4))
end

@testset "contraction" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
Expand All @@ -78,7 +99,7 @@ end
@test m3 ≈ 2m1 * m2
end

@testset "TensorAlgebra interface" begin
@testset "TensorAlgebra.contract interface" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
Expand Down
Loading