Skip to content

Commit 885175c

Browse files
JoeyT1994mtfishman
andauthored
Fix sqrt decomp (#1669)
Co-authored-by: Matt Fishman <[email protected]>
1 parent 33eda1c commit 885175c

File tree

5 files changed

+17
-7
lines changed

5 files changed

+17
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensors"
22
uuid = "9136182c-28ba-11e9-034c-db9fb085ebd5"
33
authors = ["Matthew Fishman <[email protected]>", "Miles Stoudenmire <[email protected]>"]
4-
version = "0.9.9"
4+
version = "0.9.10"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/tensor_operations/matrix_decomposition.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,8 +581,8 @@ using NDTensors: map_diag!
581581
function sqrt_decomp(D::ITensor, u::Index, v::Index)
582582
(storage(D) isa Union{Diag,DiagBlockSparse}) ||
583583
error("Must be a diagonal matrix ITensor.")
584-
sqrtDL = diag_itensor(u, dag(u)')
585-
sqrtDR = diag_itensor(v, dag(v)')
584+
sqrtDL = adapt(datatype(D), diag_itensor(u, dag(u)'))
585+
sqrtDR = adapt(datatype(D), diag_itensor(v, dag(v)'))
586586
map_diag!(sqrt abs, sqrtDL, D)
587587
map_diag!(sqrt abs, sqrtDR, D)
588588
δᵤᵥ = copy(D)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
34
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
45
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
@@ -7,6 +8,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
89
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
910
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
11+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1012
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1113
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
1214
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"

test/base/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
23
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
34
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
45
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"

test/base/test_svd.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using ITensors
1+
using Adapt: adapt
2+
using ITensors: datatype
3+
using JLArrays: JLArray
24
using Test
35
using Suppressor
46

@@ -205,9 +207,14 @@ include(joinpath(@__DIR__, "utils", "util.jl"))
205207
end
206208

207209
for dir in [ITensors.Out, ITensors.In]
208-
L, R, spec = ITensors.factorize_svd(A, l1, l2; dir, ortho="none")
209-
@test dir == ITensors.dir(commonind(L, R))
210-
@test norm(L * R - A) <= 1e-14
210+
for arrayt in (Array, JLArray)
211+
A′ = adapt(arrayt, A)
212+
L, R, spec = ITensors.factorize_svd(A, l1, l2; dir, ortho="none")
213+
@test datatype(L) == datatype(A)
214+
@test datatype(R) == datatype(A)
215+
@test dir == ITensors.dir(commonind(L, R))
216+
@test norm(L * R - A) <= 1e-14
217+
end
211218
end
212219
end
213220

0 commit comments

Comments
 (0)