Skip to content

Commit cb52828

Browse files
committed
Truncate
1 parent e392cc8 commit cb52828

File tree

5 files changed

+130
-43
lines changed

5 files changed

+130
-43
lines changed

src/KroneckerArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ include("matrixalgebrakit.jl")
99
include("fillarrays/kroneckerarray.jl")
1010
include("fillarrays/linearalgebra.jl")
1111
include("fillarrays/matrixalgebrakit.jl")
12+
include("fillarrays/matrixalgebrakit_truncate.jl")
1213

1314
end

src/fillarrays/matrixalgebrakit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ function supremum(r1::AbstractRange, r2::AbstractUnitRange)
1515
end
1616
end
1717

18+
# Allow customization for `Eye`.
19+
_diagview(a::Eye) = parent(a)
20+
1821
function _copy_input(f::F, a::Eye) where {F}
1922
return a
2023
end
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate!
2+
3+
struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
4+
strategy::T
5+
end
6+
7+
# Avoid instantiating the identity.
8+
function Base.getindex(a::EyeKronecker, I::Vararg{CartesianProduct{Colon},2})
9+
return a.a a.b[I[1].b, I[2].b]
10+
end
11+
function Base.getindex(a::KroneckerEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
12+
return a.a[I[1].a, I[2].a] a.b
13+
end
14+
function Base.getindex(a::EyeEye, I::Vararg{CartesianProduct{Colon,Colon},2})
15+
return a
16+
end
17+
18+
using FillArrays: OnesVector
19+
const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
20+
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
21+
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
22+
23+
function MatrixAlgebraKit.findtruncated(
24+
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
25+
)
26+
I = findtruncated(Vector(values), strategy.strategy)
27+
prods = collect(only(axes(values)).product)[I]
28+
I_data = unique(map(x -> x.a, prods))
29+
# Drop truncations that occur within the identity.
30+
I_data = filter(I_data) do i
31+
return count(x -> x.a == i, prods) == length(values.a)
32+
end
33+
return (:) × I_data
34+
end
35+
function MatrixAlgebraKit.findtruncated(
36+
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
37+
)
38+
I = findtruncated(Vector(values), strategy.strategy)
39+
prods = collect(only(axes(values)).product)[I]
40+
I_data = unique(map(x -> x.b, prods))
41+
# Drop truncations that occur within the identity.
42+
I_data = filter(I_data) do i
43+
return count(x -> x.b == i, prods) == length(values.b)
44+
end
45+
return I_data × (:)
46+
end
47+
function MatrixAlgebraKit.findtruncated(
48+
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
49+
)
50+
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
51+
end
52+
53+
for f in [:eig_trunc!, :eigh_trunc!]
54+
@eval begin
55+
function MatrixAlgebraKit.truncate!(
56+
::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy
57+
)
58+
return truncate!($f, DV, KroneckerTruncationStrategy(strategy))
59+
end
60+
function MatrixAlgebraKit.truncate!(
61+
::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy
62+
)
63+
I = findtruncated(diagview(D), strategy)
64+
return (D[I, I], V[(:) × (:), I])
65+
end
66+
end
67+
end
68+
69+
function MatrixAlgebraKit.truncate!(
70+
f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy
71+
)
72+
return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy))
73+
end
74+
function MatrixAlgebraKit.truncate!(
75+
::typeof(svd_trunc!),
76+
(U, S, Vᴴ)::NTuple{3,KroneckerMatrix},
77+
strategy::KroneckerTruncationStrategy,
78+
)
79+
I = findtruncated(diagview(S), strategy)
80+
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
81+
end

src/matrixalgebrakit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ using MatrixAlgebraKit:
3232
truncate!
3333

3434
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
35+
# Allow customization for `Eye`.
36+
_diagview(a::AbstractMatrix) = diagview(a)
3537
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
36-
return diagview(a.a) diagview(a.b)
38+
return _diagview(a.a) _diagview(a.b)
3739
end
3840

3941
struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm

test/test_fillarrays_matrixalgebrakit.jl

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -176,51 +176,51 @@ herm(a) = parent(hermitianpart(a))
176176
end
177177
end
178178

179-
## # svd_trunc
180-
## for elt in (Float32, ComplexF32)
181-
## a = Eye{elt}(3) ⊗ randn(elt, 3, 3)
182-
## # TODO: Type inference is broken for `svd_trunc`,
183-
## # look into fixing it.
184-
## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
185-
## u, s, v = svd_trunc(a; trunc=(; maxrank=7))
186-
## @test eltype(u) === elt
187-
## @test eltype(s) === real(elt)
188-
## @test eltype(v) === elt
189-
## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
190-
## @test Matrix(u * s * v) ≈ u′ * s′ * v′
191-
## @test arguments(u, 1) isa Eye{elt}
192-
## @test arguments(s, 1) isa Eye{real(elt)}
193-
## @test arguments(v, 1) isa Eye{elt}
194-
## @test size(u) == (9, 6)
195-
## @test size(s) == (6, 6)
196-
## @test size(v) == (6, 9)
197-
## end
179+
# svd_trunc
180+
for elt in (Float32, ComplexF32)
181+
a = Eye{elt}(3, 3) randn(elt, 3, 3)
182+
# TODO: Type inference is broken for `svd_trunc`,
183+
# look into fixing it.
184+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
185+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
186+
@test eltype(u) === elt
187+
@test eltype(s) === real(elt)
188+
@test eltype(v) === elt
189+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
190+
@test Matrix(u * s * v) u′ * s′ * v′
191+
@test arguments(u, 1) isa Eye{elt}
192+
@test arguments(s, 1) isa Eye{real(elt)}
193+
@test arguments(v, 1) isa Eye{elt}
194+
@test size(u) == (9, 6)
195+
@test size(s) == (6, 6)
196+
@test size(v) == (6, 9)
197+
end
198198

199-
## for elt in (Float32, ComplexF32)
200-
## a = randn(elt, 3, 3) ⊗ Eye{elt}(3)
201-
## # TODO: Type inference is broken for `svd_trunc`,
202-
## # look into fixing it.
203-
## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
204-
## u, s, v = svd_trunc(a; trunc=(; maxrank=7))
205-
## @test eltype(u) === elt
206-
## @test eltype(s) === real(elt)
207-
## @test eltype(v) === elt
208-
## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
209-
## @test Matrix(u * s * v) ≈ u′ * s′ * v′
210-
## @test arguments(u, 2) isa Eye{elt}
211-
## @test arguments(s, 2) isa Eye{real(elt)}
212-
## @test arguments(v, 2) isa Eye{elt}
213-
## @test size(u) == (9, 6)
214-
## @test size(s) == (6, 6)
215-
## @test size(v) == (6, 9)
216-
## end
199+
for elt in (Float32, ComplexF32)
200+
a = randn(elt, 3, 3) Eye{elt}(3, 3)
201+
# TODO: Type inference is broken for `svd_trunc`,
202+
# look into fixing it.
203+
# u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7))
204+
u, s, v = svd_trunc(a; trunc=(; maxrank=7))
205+
@test eltype(u) === elt
206+
@test eltype(s) === real(elt)
207+
@test eltype(v) === elt
208+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6))
209+
@test Matrix(u * s * v) u′ * s′ * v′
210+
@test arguments(u, 2) isa Eye{elt}
211+
@test arguments(s, 2) isa Eye{real(elt)}
212+
@test arguments(v, 2) isa Eye{elt}
213+
@test size(u) == (9, 6)
214+
@test size(s) == (6, 6)
215+
@test size(v) == (6, 9)
216+
end
217217

218-
## a = Eye(3) ⊗ Eye(3)
219-
## @test_throws ArgumentError svd_trunc(a)
218+
a = Eye(3, 3) Eye(3, 3)
219+
@test_throws ArgumentError svd_trunc(a)
220220

221221
# svd_vals
222222
for elt in (Float32, ComplexF32)
223-
a = Eye{elt}(3) randn(elt, 3, 3)
223+
a = Eye{elt}(3, 3) randn(elt, 3, 3)
224224
d = @constinferred svd_vals(a)
225225
d′ = svd_vals(Matrix(a))
226226
@test sort(Vector(d); by=abs) sort(d′; by=abs)
@@ -246,12 +246,12 @@ herm(a) = parent(hermitianpart(a))
246246
end
247247

248248
## # left_null
249-
## a = Eye(3) ⊗ randn(3, 3)
249+
## a = Eye(3, 3) ⊗ randn(3, 3)
250250
## n = @constinferred left_null(a)
251251
## @test norm(n' * a) ≈ 0
252252
## @test arguments(n, 1) isa Eye
253253

254-
## a = randn(3, 3) ⊗ Eye(3)
254+
## a = randn(3, 3) ⊗ Eye(3, 3)
255255
## n = @constinferred left_null(a)
256256
## @test norm(n' * a) ≈ 0
257257
## @test arguments(n, 2) isa Eye

0 commit comments

Comments
 (0)