Skip to content

Commit 07fe41d

Browse files
authored
define matrix functions (#80)
1 parent c81d2e6 commit 07fe41d

File tree

3 files changed

+74
-6
lines changed

3 files changed

+74
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.10"
4+
version = "0.5.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -22,7 +22,7 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"
2222
Accessors = "0.1.42"
2323
BlockArrays = "1.7"
2424
BlockSparseArrays = "0.10"
25-
GradedArrays = "0.4.14"
25+
GradedArrays = "0.4"
2626
HalfIntegers = "1.6"
2727
LRUCache = "1.6"
2828
LinearAlgebra = "1.10"

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,41 @@ using TensorAlgebra:
1414
Matricize,
1515
blockedperm,
1616
genperm,
17+
matricize,
1718
unmatricize
1819

20+
const MATRIX_FUNCTIONS = [
21+
:exp,
22+
:cis,
23+
:log,
24+
:sqrt,
25+
:cbrt,
26+
:cos,
27+
:sin,
28+
:tan,
29+
:csc,
30+
:sec,
31+
:cot,
32+
:cosh,
33+
:sinh,
34+
:tanh,
35+
:csch,
36+
:sech,
37+
:coth,
38+
:acos,
39+
:asin,
40+
:atan,
41+
:acsc,
42+
:asec,
43+
:acot,
44+
:acosh,
45+
:asinh,
46+
:atanh,
47+
:acsch,
48+
:asech,
49+
:acoth,
50+
]
51+
1952
function TensorAlgebra.output_axes(
2053
::typeof(contract),
2154
biperm_dest::AbstractBlockPermutation{2},
@@ -70,3 +103,17 @@ function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbipe
70103
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
71104
return a_dest
72105
end
106+
107+
for f in MATRIX_FUNCTIONS
108+
@eval begin
109+
function TensorAlgebra.$f(
110+
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
111+
)
112+
a_mat = matricize(a, biperm)
113+
permuted_axes = axes(a)[biperm]
114+
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
115+
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
116+
return unmatricize(fa_mat, permuted_axes)
117+
end
118+
end
119+
end

test/test_contraction.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
using LinearAlgebra: mul!
2-
using Test: @test, @testset, @test_broken
2+
using Test: @test, @testset, @test_broken, @test_throws
33

44
using BlockSparseArrays: BlockSparseArray
5-
using FusionTensors: FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
5+
using FusionTensors:
6+
FusionTensors, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes, to_fusiontensor
67
using GradedArrays: SU2, U1, dual, gradedrange
7-
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
8+
using TensorAlgebra:
9+
TensorAlgebra, contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
810

911
include("setup.jl")
1012

@@ -52,6 +54,25 @@ end
5254
@test ft1 ft2
5355
end
5456

57+
@testset "Matrix functions" begin
58+
sds22 = [
59+
0.25 0.0 0.0 0.0
60+
0.0 -0.25 0.5 0.0
61+
0.0 0.5 -0.25 0.0
62+
0.0 0.0 0.0 0.25
63+
]
64+
t = reshape(sds22, (2, 2, 2, 2))
65+
g2 = gradedrange([SU2(1//2) => 1])
66+
ft = to_fusiontensor(t, (g2, g2), (dual(g2), dual(g2)))
67+
for f in setdiff(FusionTensors.MATRIX_FUNCTIONS, [:acoth, :cbrt])
68+
t2 = reshape((@eval Base.$f)(sds22), (2, 2, 2, 2))
69+
ft2 = to_fusiontensor(t2, (g2, g2), (dual(g2), dual(g2)))
70+
@test (@eval TensorAlgebra.$f)(ft, (1, 2, 3, 4), (1, 2), (3, 4)) ft2
71+
end
72+
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 2, 3), (4,))
73+
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 3), (2, 4))
74+
end
75+
5576
@testset "contraction" begin
5677
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
5778
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
@@ -78,7 +99,7 @@ end
7899
@test m3 2m1 * m2
79100
end
80101

81-
@testset "TensorAlgebra interface" begin
102+
@testset "TensorAlgebra.contract interface" begin
82103
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
83104
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])
84105
g3 = gradedrange([U1(-1) => 1, U1(0) => 2, U1(1) => 1])

0 commit comments

Comments
 (0)