Skip to content

Commit 5aa3d31

Browse files
authored
Add tests for svd_trunc (#32)
1 parent ca18629 commit 5aa3d31

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.25"
4+
version = "0.1.26"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
2323
[compat]
2424
Adapt = "4.3"
2525
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.8"
26+
BlockSparseArrays = "0.8.1"
2727
DerivableInterfaces = "0.5"
2828
DiagonalArrays = "0.3.5"
2929
FillArrays = "1.13"

src/kroneckerarray.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22
function _convert(A::Type{<:AbstractArray}, a::AbstractArray)
33
return convert(A, a)
44
end
5+
# Custom `_convert` works around the issue that
6+
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isnt' defined
7+
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
8+
# https://github.com/JuliaLang/julia/pull/52487).
9+
# TODO: Delete once we drop support for Julia v1.10.
10+
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
11+
_construct(A::Type{<:Diagonal}, a::AbstractMatrix) = A(diag(a))
12+
function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
13+
LinearAlgebra.checksquare(a)
14+
return isdiag(a) ? _construct(A, a) : throw(InexactError(:convert, A, a))
15+
end
516

617
struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
718
a::A
@@ -39,7 +50,7 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
3950
end
4051

4152
function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B}
42-
return KroneckerArray(convert(A, arg1(a)), convert(B, arg2(a)))
53+
return KroneckerArray(_convert(A, arg1(a)), _convert(B, arg2(a)))
4354
end
4455

4556
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
2121
Adapt = "4"
2222
Aqua = "0.8"
2323
BlockArrays = "1.6"
24-
BlockSparseArrays = "0.8"
24+
BlockSparseArrays = "0.8.1"
2525
DerivableInterfaces = "0.5"
2626
DiagonalArrays = "0.3.7"
2727
FillArrays = "1"

test/test_blocksparsearrays.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using FillArrays: Eye, SquareEye
66
using JLArrays: JLArray
77
using KroneckerArrays: KroneckerArray, , ×, arg1, arg2
88
using LinearAlgebra: norm
9-
using MatrixAlgebraKit: svd_compact
9+
using MatrixAlgebraKit: svd_compact, svd_trunc
10+
using StableRNGs: StableRNG
1011
using Test: @test, @test_broken, @testset
1112
using TestExtras: @constinferred
1213

@@ -273,6 +274,23 @@ end
273274
# Broken operations
274275
@test_broken a[Block.(1:2), Block(2)]
275276

277+
# svd_trunc
278+
dev = adapt(arrayt)
279+
r = @constinferred blockrange([2 × 2, 3 × 3])
280+
rng = StableRNG(1234)
281+
d = Dict(
282+
Block(1, 1) => Eye{elt}(2, 2) randn(rng, elt, 2, 2),
283+
Block(2, 2) => Eye{elt}(3, 3) randn(rng, elt, 3, 3),
284+
)
285+
a = @constinferred dev(blocksparse(d, r, r))
286+
if arrayt === Array
287+
u, s, v = svd_trunc(a; trunc=(; maxrank=6))
288+
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
289+
@test Matrix(u * s * v) u′ * s′ * v′
290+
else
291+
@test_broken svd_trunc(a; trunc=(; maxrank=6))
292+
end
293+
276294
@testset "Block deficient" begin
277295
da = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
278296
a = @constinferred dev(blocksparse(da, r, r))

0 commit comments

Comments
 (0)