diff --git a/Project.toml b/Project.toml index 25f1f8c..46a935f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FusionTensors" uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e" authors = ["ITensor developers and contributors"] -version = "0.5.10" +version = "0.5.11" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" Accessors = "0.1.42" BlockArrays = "1.7" BlockSparseArrays = "0.10" -GradedArrays = "0.4.14" +GradedArrays = "0.4" HalfIntegers = "1.6" LRUCache = "1.6" LinearAlgebra = "1.10" diff --git a/src/fusiontensor/tensor_algebra_interface.jl b/src/fusiontensor/tensor_algebra_interface.jl index 1cf948e..1c29973 100644 --- a/src/fusiontensor/tensor_algebra_interface.jl +++ b/src/fusiontensor/tensor_algebra_interface.jl @@ -14,8 +14,41 @@ using TensorAlgebra: Matricize, blockedperm, genperm, + matricize, unmatricize +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + function TensorAlgebra.output_axes( ::typeof(contract), biperm_dest::AbstractBlockPermutation{2}, @@ -70,3 +103,17 @@ function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbipe data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest) return a_dest end + +for f in MATRIX_FUNCTIONS + @eval begin + function TensorAlgebra.$f( + a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs... + ) + a_mat = matricize(a, biperm) + permuted_axes = axes(a)[biperm] + checkspaces_dual(codomain(permuted_axes), domain(permuted_axes)) + fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...)) + return unmatricize(fa_mat, permuted_axes) + end + end +end diff --git a/test/test_contraction.jl b/test/test_contraction.jl index a9fb202..c07d2e9 100644 --- a/test/test_contraction.jl +++ b/test/test_contraction.jl @@ -1,10 +1,12 @@ using LinearAlgebra: mul! -using Test: @test, @testset, @test_broken +using Test: @test, @testset, @test_broken, @test_throws using BlockSparseArrays: BlockSparseArray -using FusionTensors: FusionTensor, FusionTensorAxes, domain_axes, codomain_axes +using FusionTensors: + FusionTensors, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes, to_fusiontensor using GradedArrays: SU2, U1, dual, gradedrange -using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize! +using TensorAlgebra: + TensorAlgebra, contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize! include("setup.jl") @@ -52,6 +54,25 @@ end @test ft1 ≈ ft2 end +@testset "Matrix functions" begin + sds22 = [ + 0.25 0.0 0.0 0.0 + 0.0 -0.25 0.5 0.0 + 0.0 0.5 -0.25 0.0 + 0.0 0.0 0.0 0.25 + ] + t = reshape(sds22, (2, 2, 2, 2)) + g2 = gradedrange([SU2(1//2) => 1]) + ft = to_fusiontensor(t, (g2, g2), (dual(g2), dual(g2))) + for f in setdiff(FusionTensors.MATRIX_FUNCTIONS, [:acoth, :cbrt]) + t2 = reshape((@eval Base.$f)(sds22), (2, 2, 2, 2)) + ft2 = to_fusiontensor(t2, (g2, g2), (dual(g2), dual(g2))) + @test (@eval TensorAlgebra.$f)(ft, (1, 2, 3, 4), (1, 2), (3, 4)) ≈ ft2 + end + @test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 2, 3), (4,)) + @test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 3), (2, 4)) +end + @testset "contraction" begin g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3]) g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]) @@ -78,7 +99,7 @@ end @test m3 ≈ 2m1 * m2 end -@testset "TensorAlgebra interface" begin +@testset "TensorAlgebra.contract interface" begin g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3]) g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1]) g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])