Skip to content

Commit 48edc0e

Browse files
committed
define matrix functions
1 parent 5472a79 commit 48edc0e

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FusionTensors"
22
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.9"
4+
version = "0.5.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -28,7 +28,7 @@ LRUCache = "1.6"
2828
LinearAlgebra = "1.10"
2929
Random = "1.10"
3030
Strided = "2.3"
31-
TensorAlgebra = "0.4"
31+
TensorAlgebra = "0.4.1"
3232
TensorProducts = "0.1.7"
3333
TypeParameterAccessors = "0.4"
3434
WignerSymbols = "2.0.0"

src/fusiontensor/tensor_algebra_interface.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using TensorAlgebra:
1414
Matricize,
1515
blockedperm,
1616
genperm,
17+
matricize,
1718
unmatricize
1819

1920
function TensorAlgebra.output_axes(
@@ -70,3 +71,17 @@ function TensorAlgebra.unmatricizeadd!(a_dest::FusionTensor, a_dest_mat, invbipe
7071
data_matrix(a_dest) .= α .* data_matrix(a12) .+ β .* data_matrix(a_dest)
7172
return a_dest
7273
end
74+
75+
for f in TensorAlgebra.MATRIX_FUNCTIONS
76+
@eval begin
77+
function TensorAlgebra.$f(
78+
a::FusionTensor, biperm::AbstractBlockPermutation{2}; kwargs...
79+
)
80+
a_mat = matricize(a, biperm)
81+
permuted_axes = axes(a)[biperm]
82+
checkspaces_dual(codomain(permuted_axes), domain(permuted_axes))
83+
fa_mat = set_data_matrix(a_mat, Base.$f(data_matrix(a_mat); kwargs...))
84+
return unmatricize(fa_mat, permuted_axes)
85+
end
86+
end
87+
end

test/test_contraction.jl

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using LinearAlgebra: mul!
2-
using Test: @test, @testset, @test_broken
2+
using Test: @test, @testset, @test_broken, @test_throws
33

44
using BlockSparseArrays: BlockSparseArray
55
using FusionTensors:
6-
FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes
7-
using GradedArrays: U1, dual, gradedrange
8-
using TensorAlgebra: contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
6+
FusionMatrix, FusionTensor, FusionTensorAxes, domain_axes, codomain_axes, to_fusiontensor
7+
using GradedArrays: SU2, U1, dual, gradedrange
8+
using TensorAlgebra:
9+
TensorAlgebra, contract, matricize, permmortar, tuplemortar, unmatricize, unmatricize!
910

1011
include("setup.jl")
1112

@@ -33,6 +34,54 @@ include("setup.jl")
3334
@test ft1 ft2
3435
end
3536

37+
@testset "Matrix functions" begin
38+
sds22 = [
39+
0.25 0.0 0.0 0.0
40+
0.0 -0.25 0.5 0.0
41+
0.0 0.5 -0.25 0.0
42+
0.0 0.0 0.0 0.25
43+
]
44+
t = reshape(sds22, (2, 2, 2, 2))
45+
g2 = gradedrange([SU2(1//2) => 1])
46+
ft = to_fusiontensor(t, (g2, g2), (dual(g2), dual(g2)))
47+
for f in [
48+
:exp,
49+
:cis,
50+
:log,
51+
:sqrt,
52+
:cbrt,
53+
:cos,
54+
:sin,
55+
:tan,
56+
:csc,
57+
:sec,
58+
:cot,
59+
:cosh,
60+
:sinh,
61+
:tanh,
62+
:csch,
63+
:sech,
64+
:coth,
65+
:acos,
66+
:asin,
67+
:atan,
68+
:acsc,
69+
:asec,
70+
:acot,
71+
:acosh,
72+
:asinh,
73+
:atanh,
74+
:acsch,
75+
:asech,
76+
]
77+
t2 = @eval reshape(Base.$f(sds22), (2, 2, 2, 2))
78+
ft2 = to_fusiontensor(t2, (g2, g2), (dual(g2), dual(g2)))
79+
@test (@eval TensorAlgebra.$f(ft, (1, 2, 3, 4), (1, 2), (3, 4))) ft2
80+
end
81+
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 2, 3), (4,))
82+
@test_throws ArgumentError TensorAlgebra.exp(ft, (1, 2, 3, 4), (1, 3), (2, 4))
83+
end
84+
3685
@testset "contraction" begin
3786
g1 = gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 3])
3887
g2 = gradedrange([U1(0) => 2, U1(1) => 2, U1(3) => 1])

0 commit comments

Comments
 (0)