Skip to content

Commit 3045584

Browse files
authored
Fix truncated SVD (#39)
1 parent b5dc783 commit 3045584

File tree

4 files changed

+21
-26
lines changed

4 files changed

+21
-26
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.6"
4+
version = "0.4.7"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -25,7 +25,7 @@ GradedArraysTensorAlgebraExt = "TensorAlgebra"
2525

2626
[compat]
2727
BlockArrays = "1.6.0"
28-
BlockSparseArrays = "0.6.5"
28+
BlockSparseArrays = "0.7.0"
2929
Compat = "4.16.0"
3030
DerivableInterfaces = "0.4.4"
3131
FillArrays = "1.13.0"

src/factorizations.jl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,29 +42,6 @@ end
4242

4343
const TGradedUSVᴴ = Tuple{<:GradedMatrix,<:GradedMatrix,<:GradedMatrix}
4444

45-
function BlockSparseArrays.similar_truncate(
46-
::typeof(svd_trunc!),
47-
(U, S, Vᴴ)::TGradedUSVᴴ,
48-
strategy::BlockPermutedDiagonalTruncationStrategy,
49-
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
50-
)
51-
u_axis, v_axis = axes(S)
52-
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
53-
s_lengths = map(counter, blocks(u_axis))
54-
u_sectors = sectors(u_axis) .=> s_lengths
55-
v_sectors = sectors(v_axis) .=> s_lengths
56-
u_sectors_filtered = filter(>(0) last, u_sectors)
57-
v_sectors_filtered = filter(>(0) last, v_sectors)
58-
u_axis′ = gradedrange(u_sectors_filtered)
59-
u_axis = isdual(u_axis) ? dual(u_axis′) : u_axis′
60-
v_axis′ = gradedrange(v_sectors_filtered)
61-
v_axis = isdual(v_axis) ? dual(v_axis′) : v_axis′
62-
= similar(U, axes(U, 1), dual(u_axis))
63-
= similar(S, u_axis, v_axis)
64-
Ṽᴴ = similar(Vᴴ, dual(v_axis), axes(Vᴴ, 2))
65-
return Ũ, S̃, Ṽᴴ
66-
end
67-
6845
function BlockSparseArrays.similar_output(
6946
::typeof(qr_compact!), A::GradedMatrix, R_axis, alg::BlockPermutedDiagonalAlgorithm
7047
)

src/gradedunitrange.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,21 @@ function BlockSparseArrays.blockedunitrange_getindices(
320320
new_range = blockedrange(new_first, length.(new_axes))
321321
return GradedUnitRange(new_axes, new_range)
322322
end
323+
324+
using BlockArrays: BlockedVector
325+
function BlockSparseArrays.blockedunitrange_getindices(
326+
a::AbstractGradedUnitRange, indices::AbstractVector{Bool}
327+
)
328+
blocked_indices = BlockedVector(indices, axes(a))
329+
bs = map(Base.OneTo(blocklength(blocked_indices))) do b
330+
binds = blocked_indices[Block(b)]
331+
bstart = blockfirsts(only(axes(blocked_indices)))[b]
332+
return findall(binds) .+ (bstart - 1)
333+
end
334+
keep = map(!isempty, bs)
335+
secs = sectors(a)[keep]
336+
bs = bs[keep]
337+
r = gradedrange(secs .=> length.(bs); isdual=isdual(a))
338+
I = mortar(bs, (r,))
339+
return I
340+
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1717
[compat]
1818
Aqua = "0.8.11"
1919
BlockArrays = "1.6.0"
20-
BlockSparseArrays = "0.6"
20+
BlockSparseArrays = "0.7.0"
2121
GradedArrays = "0.4"
2222
LinearAlgebra = "1.10.0"
2323
MatrixAlgebraKit = "0.2"

0 commit comments

Comments
 (0)