Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FusionTensors"
uuid = "e16ca583-1f51-4df0-8e12-57d32947d33e"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.9"
version = "0.5.10"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
22 changes: 7 additions & 15 deletions src/fusiontensor/linear_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,17 @@ function LinearAlgebra.norm(ft::FusionTensor)
return sqrt(n2)
end

LinearAlgebra.normalize(ft::FusionTensor) = set_data_matrix(ft, data_matrix(ft) / norm(ft))

function LinearAlgebra.normalize!(ft::FusionTensor)
data_matrix(ft) ./= norm(ft)
return ft
end

function LinearAlgebra.tr(ft::FusionTensor)
m = data_matrix(ft)
row_sectors = sectors(codomain_axis(ft))
return sum(eachblockstoredindex(m); init=zero(eltype(ft))) do b
return quantum_dimension(row_sectors[Int(first(Tuple(b)))]) * tr(m[b])
end
end

function LinearAlgebra.qr(ft::FusionTensor)
qmat, rmat = block_qr(data_matrix(ft))
qtens = FusionTensor(qmat, codomain_axes(ft), (axes(qmat, 2),))
rtens = FusionTensor(rmat, (axes(rmat, 1),), domain_axes(ft))
return qtens, rtens
end

function LinearAlgebra.svd(ft::FusionTensor)
umat, s, vmat = block_svd(data_matrix(ft))
utens = FusionTensor(umat, codomain_axes(ft), (axes(umat, 2),))
stens = FusionTensor(s, (axes(umat, 1),), (axes(vmat, 2),))
vtens = FusionTensor(vmat, (axes(vmat, 1),), domain_axes(ft))
return utens, stens, vtens
end
9 changes: 8 additions & 1 deletion test/test_linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearAlgebra: norm, tr
using LinearAlgebra: norm, normalize, normalize!, tr
using Test: @test, @testset

using BlockArrays: BlockArrays
Expand Down Expand Up @@ -27,5 +27,12 @@ include("setup.jl")
@test isnothing(check_sanity(ft))
@test norm(ft) ≈ √3 / 2
@test isapprox(tr(ft), 0; atol=eps(Float64))

ft2 = normalize(ft)
@test norm(ft2) ≈ 1.0
@test norm(ft) ≈ √3 / 2 # unaffected by normalize
@test ft ≈ √3 / 2 * ft2
normalize!(ft)
@test norm(ft) ≈ 1.0
end
end
Loading