Skip to content

Commit a40d16e

Browse files
committed
[WIP] Start using TensorAlgebra.factorize
1 parent 018210b commit a40d16e

File tree

4 files changed

+35
-40
lines changed

4 files changed

+35
-40
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1212
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
13+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1314
UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6"
1415
UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d"
1516
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
@@ -34,6 +35,7 @@ LinearAlgebra = "1.10"
3435
MapBroadcast = "0.1.5"
3536
NamedDimsArrays = "0.6"
3637
SparseArraysBase = "0.5"
38+
TensorAlgebra = "0.2.10"
3739
UnallocatedArrays = "0.1.1"
3840
UnspecifiedTypes = "0.1.1"
3941
VectorInterface = "0.5"

src/quirks.jl

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,30 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a))
2020
# TODO: Investigate this and see if we can get rid of it.
2121
Base.Broadcast.extrude(a::AbstractITensor) = a
2222

23-
# TODO: This is just a stand-in for truncated SVD
24-
# that only makes use of `maxdim`, just to get some
25-
# functionality running in `ITensorMPS.jl`.
26-
# Define a proper truncated SVD in
27-
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
28-
function svd_truncated(a::AbstractITensor, codomain_inds; maxdim)
29-
U, S, V = svd(a, codomain_inds)
30-
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
31-
u = commonind(U, S)
32-
v = commonind(V, S)
33-
us = uniqueinds(U, S)
34-
vs = uniqueinds(V, S)
35-
U′ = U[(us .=> :)..., u => r]
36-
S′ = S[u => r, v => r]
37-
V′ = V[v => r, (vs .=> :)...]
38-
return U′, S′, V′
23+
function translate_factorize_kwargs(;
24+
# ITensors.jl kwargs.
25+
ortho=nothing,
26+
cutoff=nothing,
27+
maxdim=nothing,
28+
# MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs.
29+
orth=nothing,
30+
trunc=nothing,
31+
kwargs...,
32+
)
33+
@show ortho, cutoff, maxdim
34+
@show orth, trunc
35+
@show kwargs
36+
return error()
3937
end
4038

41-
using LinearAlgebra: qr, svd
42-
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
43-
function factorize(
44-
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs...
45-
)
46-
# TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
47-
# See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
48-
codomain_inds′ = if ortho == "left"
49-
intersect(inds(a), codomain_inds)
50-
elseif ortho == "right"
51-
setdiff(inds(a), codomain_inds)
52-
else
53-
error("Bad `ortho` input.")
54-
end
55-
F1, F2 = if isnothing(maxdim) && isnothing(cutoff)
56-
qr(a, codomain_inds′)
57-
else
58-
U, S, V = svd_truncated(a, codomain_inds′; maxdim)
59-
U, S * V
60-
end
61-
if ortho == "right"
62-
F2, F1 = F1, F2
63-
end
64-
return F1, F2, (; truncerr=zero(Bool),)
39+
using TensorAlgebra: TensorAlgebra, factorize
40+
function TensorAlgebra.factorize(a::AbstractITensor, codomain_inds, domain_inds; kwargs...)
41+
return invoke(
42+
factorize,
43+
Tuple{AbstractNamedDimsArray,Any,Any},
44+
a,
45+
codomain_inds,
46+
domain_inds;
47+
translate_factorize_kwargs(; kwargs...)...,
48+
)
6549
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
66
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
77
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
910
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
1011
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1112
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1213
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1314
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
15+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[compat]

test/test_basics.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using ITensorBase:
1919
using NamedDimsArrays: dename, name, named
2020
using SparseArraysBase: oneelement
2121
using SymmetrySectors: U1
22+
using LinearAlgebra: factorize
2223
using Test: @test, @test_broken, @test_throws, @testset
2324

2425
@testset "ITensorBase" begin
@@ -164,4 +165,10 @@ using Test: @test, @test_broken, @test_throws, @testset
164165
@test hasqns(j)
165166
@test hasqns(a)
166167
end
168+
@testset "factorize" begin
169+
i = Index(2)
170+
j = Index(2)
171+
a = randn(i, j)
172+
x, y = factorize(a, (i,))
173+
end
167174
end

0 commit comments

Comments
 (0)