Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.10"
version = "0.5.11"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
Accessors = "0.1.42"
BlockArrays = "1.7"
BlockSparseArrays = "0.10"
GradedArrays = "0.4.14"
GradedArrays = "0.4"
HalfIntegers = "1.6"
LRUCache = "1.6"
LinearAlgebra = "1.10"
Expand Down
47 changes: 47 additions & 0 deletions src/fusiontensor/tensor_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,41 @@ using TensorAlgebra:
Matricize,
blockedperm,
genperm,
matricize,
unmatricize

const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

function TensorAlgebra.output_axes(
::typeof(contract),
biperm_dest::AbstractBlockPermutation{2},
Expand Down Expand Up @@ -70,3 +103,17 @@ function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbipe
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
return a_dest
end

for f in MATRIX_FUNCTIONS
@eval begin
function TensorAlgebra.$f(
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
)
a_mat = matricize(a, biperm)
permuted_axes = axes(a)[biperm]
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
return unmatricize(fa_mat, permuted_axes)
end
end
end
29 changes: 25 additions & 4 deletions test/test_contraction.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using LinearAlgebra: mul!
using Test: @test, @testset, @test_broken
using Test: @test, @testset, @test_broken, @test_throws

using BlockSparseArrays: BlockSparseArray
using FusionTensors: FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
using FusionTensors:
FusionTensors, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes, to_fusiontensor
using GradedArrays: SU2, U1, dual, gradedrange
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
using TensorAlgebra:
TensorAlgebra, contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!

include("setup.jl")

Expand Down Expand Up @@ -52,6 +54,25 @@ end
@test ft1 ≈ ft2
end

@testset "Matrix functions" begin
sds22 = [
0.25 0.0 0.0 0.0
0.0 -0.25 0.5 0.0
0.0 0.5 -0.25 0.0
0.0 0.0 0.0 0.25
]
t = reshape(sds22, (2, 2, 2, 2))
g2 = gradedrange([SU2(1//2) => 1])
ft = to_fusiontensor(t, (g2, g2), (dual(g2), dual(g2)))
for f in setdiff(FusionTensors.MATRIX_FUNCTIONS, [:acoth, :cbrt])
t2 = reshape((@eval Base.$f)(sds22), (2, 2, 2, 2))
ft2 = to_fusiontensor(t2, (g2, g2), (dual(g2), dual(g2)))
@test (@eval TensorAlgebra.$f)(ft, (1, 2, 3, 4), (1, 2), (3, 4)) ≈ ft2
end
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 2, 3), (4,))
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 3), (2, 4))
end

@testset "contraction" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
Expand All @@ -78,7 +99,7 @@ end
@test m3 ≈ 2m1 * m2
end

@testset "TensorAlgebra interface" begin
@testset "TensorAlgebra.contract interface" begin
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])
Expand Down
Loading