@@ -45,69 +45,11 @@ function MatrixAlgebraKit.findtruncated(
45
45
return indexmask
46
46
end
47
47
48
- function similar_truncate (
49
- :: typeof (svd_trunc!),
50
- (U, S, Vᴴ):: TBlockUSV ᴴ,
51
- strategy:: BlockPermutedDiagonalTruncationStrategy ,
52
- indexmask= MatrixAlgebraKit. findtruncated (diagview (S), strategy),
53
- )
54
- ax = axes (S, 1 )
55
- counter = Base. Fix1 (count, Base. Fix1 (getindex, indexmask))
56
- s_lengths = filter! (> (0 ), map (counter, blocks (ax)))
57
- s_axis = blockedrange (s_lengths)
58
- Ũ = similar (U, axes (U, 1 ), s_axis)
59
- S̃ = similar (S, s_axis, s_axis)
60
- Ṽᴴ = similar (Vᴴ, s_axis, axes (Vᴴ, 2 ))
61
- return Ũ, S̃, Ṽᴴ
62
- end
63
-
64
48
function MatrixAlgebraKit. truncate! (
65
49
:: typeof (svd_trunc!),
66
50
(U, S, Vᴴ):: TBlockUSV ᴴ,
67
51
strategy:: BlockPermutedDiagonalTruncationStrategy ,
68
52
)
69
- indexmask = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
70
-
71
- # first determine the block structure of the output to avoid having assumptions on the
72
- # data structures
73
- Ũ, S̃, Ṽᴴ = similar_truncate (svd_trunc!, (U, S, Vᴴ), strategy, indexmask)
74
-
75
- # then loop over the blocks and assign the data
76
- # TODO : figure out if we can presort and loop over the blocks -
77
- # for now this has issues with missing blocks
78
- bI_Us = collect (eachblockstoredindex (U))
79
- bI_Ss = collect (eachblockstoredindex (S))
80
- bI_Vᴴs = collect (eachblockstoredindex (Vᴴ))
81
-
82
- I′ = 0 # number of skipped blocks that got fully truncated
83
- ax = axes (S, 1 )
84
- for I in 1 : blocksize (ax, 1 )
85
- b = ax[Block (I)]
86
- mask = indexmask[b]
87
-
88
- if ! any (mask)
89
- I′ += 1
90
- continue
91
- end
92
-
93
- bU_id = @something findfirst (x -> last (Tuple (x)) == Block (I), bI_Us) error (
94
- " No U-block found for $I "
95
- )
96
- bU = Tuple (bI_Us[bU_id])
97
- Ũ[bU[1 ], bU[2 ] - Block (I′)] = view (U, bU... )[:, mask]
98
-
99
- bVᴴ_id = @something findfirst (x -> first (Tuple (x)) == Block (I), bI_Vᴴs) error (
100
- " No Vᴴ-block found for $I "
101
- )
102
- bVᴴ = Tuple (bI_Vᴴs[bVᴴ_id])
103
- Ṽᴴ[bVᴴ[1 ] - Block (I′), bVᴴ[2 ]] = view (Vᴴ, bVᴴ... )[mask, :]
104
-
105
- bS_id = findfirst (x -> last (Tuple (x)) == Block (I), bI_Ss)
106
- if ! isnothing (bS_id)
107
- bS = Tuple (bI_Ss[bS_id])
108
- S̃[(bS .- Block (I′)). .. ] = Diagonal (diagview (view (S, bS... ))[mask])
109
- end
110
- end
111
-
112
- return Ũ, S̃, Ṽᴴ
53
+ I = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
54
+ return (U[:, I], S[I, I], Vᴴ[I, :])
113
55
end
0 commit comments