Skip to content

Commit 326fa8b

Browse files
authored
Start using TensorAlgebra.factorize (#64)
1 parent b6d96bc commit 326fa8b

File tree

4 files changed

+63
-41
lines changed

4 files changed

+63
-41
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorBase"
22
uuid = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -10,6 +10,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1212
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
13+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1314
UnallocatedArrays = "43c9e47c-e622-40fb-bf18-a09fc8c466b6"
1415
UnspecifiedTypes = "42b3faec-625b-4613-8ddc-352bf9672b8d"
1516
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
@@ -34,6 +35,7 @@ LinearAlgebra = "1.10"
3435
MapBroadcast = "0.1.5"
3536
NamedDimsArrays = "0.7"
3637
SparseArraysBase = "0.5"
38+
TensorAlgebra = "0.3"
3739
UnallocatedArrays = "0.1.1"
3840
UnspecifiedTypes = "0.1.1"
3941
VectorInterface = "0.5"

src/quirks.jl

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,37 @@ hasqns(a::AbstractITensor) = all(hasqns, inds(a))
2020
# TODO: Investigate this and see if we can get rid of it.
2121
Base.Broadcast.extrude(a::AbstractITensor) = a
2222

23-
# TODO: This is just a stand-in for truncated SVD
24-
# that only makes use of `maxdim`, just to get some
25-
# functionality running in `ITensorMPS.jl`.
26-
# Define a proper truncated SVD in
27-
# `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
28-
function svd_truncated(a::AbstractITensor, codomain_inds; maxdim)
29-
U, S, V = svd(a, codomain_inds)
30-
r = Base.OneTo(min(maxdim, minimum(Int.(size(S)))))
31-
u = commonind(U, S)
32-
v = commonind(V, S)
33-
us = uniqueinds(U, S)
34-
vs = uniqueinds(V, S)
35-
U′ = U[(us .=> :)..., u => r]
36-
S′ = S[u => r, v => r]
37-
V′ = V[v => r, (vs .=> :)...]
38-
return U′, S′, V′
39-
end
23+
# See: https://github.com/JuliaLang/julia/blob/v1.11.4/base/namedtuple.jl#L269
24+
# `filter(f, ::NamedTuple)` is available in Julia v1.11, delete once
25+
# we drop support for Julia v1.10.
26+
filter_namedtuple(f, xs::NamedTuple) = xs[filter(k -> f(xs[k]), keys(xs))]
4027

41-
using LinearAlgebra: qr, svd
42-
# TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`.
43-
function factorize(
44-
a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, ortho="left", kwargs...
28+
function translate_factorize_kwargs(;
29+
# MatrixAlgebraKit.jl/TensorAlgebra.jl kwargs.
30+
orth=nothing,
31+
rtol=nothing,
32+
maxrank=nothing,
33+
# ITensors.jl kwargs.
34+
ortho=nothing,
35+
cutoff=nothing,
36+
maxdim=nothing,
37+
kwargs...,
4538
)
46-
# TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`?
47-
# See https://github.com/ITensor/NamedDimsArrays.jl/issues/22.
48-
codomain_inds′ = if ortho == "left"
49-
intersect(inds(a), codomain_inds)
50-
elseif ortho == "right"
51-
setdiff(inds(a), codomain_inds)
52-
else
53-
error("Bad `ortho` input.")
54-
end
55-
F1, F2 = if isnothing(maxdim) && isnothing(cutoff)
56-
qr(a, codomain_inds′)
57-
else
58-
U, S, V = svd_truncated(a, codomain_inds′; maxdim)
59-
U, S * V
60-
end
61-
if ortho == "right"
62-
F2, F1 = F1, F2
63-
end
64-
return F1, F2, (; truncerr=zero(Bool),)
39+
orth::Symbol = @something orth ortho :left
40+
rtol = @something rtol cutoff Some(nothing)
41+
maxrank = @something maxrank maxdim Some(nothing)
42+
!isnothing(maxrank) && error("`maxrank` not supported yet.")
43+
return filter_namedtuple(!isnothing, (; orth, rtol, maxrank, kwargs...))
44+
end
45+
46+
using TensorAlgebra: TensorAlgebra, factorize
47+
function TensorAlgebra.factorize(a::AbstractITensor, codomain_inds, domain_inds; kwargs...)
48+
return invoke(
49+
factorize,
50+
Tuple{AbstractNamedDimsArray,Any,Any},
51+
a,
52+
codomain_inds,
53+
domain_inds;
54+
translate_factorize_kwargs(; kwargs...)...,
55+
)
6556
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1111
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1212
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1313
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
14+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516

1617
[compat]
@@ -25,4 +26,5 @@ SafeTestsets = "0.1"
2526
SparseArraysBase = "0.5"
2627
Suppressor = "0.2"
2728
SymmetrySectors = "0.1"
29+
TensorAlgebra = "0.3"
2830
Test = "1.10"

test/test_basics.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ using ITensorBase:
1919
using NamedDimsArrays: dename, name, named
2020
using SparseArraysBase: oneelement
2121
using SymmetrySectors: U1
22+
using LinearAlgebra: factorize
2223
using Test: @test, @test_broken, @test_throws, @testset
2324

25+
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
2426
@testset "ITensorBase" begin
2527
@testset "Basics" begin
2628
elt = Float64
@@ -164,4 +166,29 @@ using Test: @test, @test_broken, @test_throws, @testset
164166
@test hasqns(j)
165167
@test hasqns(a)
166168
end
169+
@testset "factorize" for elt in elts
170+
i = Index(2)
171+
j = Index(2)
172+
a = randn(elt, i, j)
173+
x, y = factorize(a, (i,))
174+
@test a x * y
175+
@test x isa ITensor
176+
@test y isa ITensor
177+
@test i inds(x)
178+
@test j inds(y)
179+
@test eltype(x) === elt
180+
@test eltype(y) === elt
181+
@test Int.(Tuple(size(x))) == (2, 2)
182+
@test Int.(Tuple(size(y))) == (2, 2)
183+
184+
i = Index(2)
185+
j = Index(2)
186+
a = randn(elt, i) * randn(elt, j)
187+
for kwargs in ((; rtol=1e-2), (; cutoff=1e-2))
188+
x, y = factorize(a, (i,); kwargs...)
189+
@test a x * y
190+
@test Int.(Tuple(size(x))) == (2, 1)
191+
@test Int.(Tuple(size(y))) == (1, 2)
192+
end
193+
end
167194
end

0 commit comments

Comments
 (0)