diff --git a/Project.toml b/Project.toml index 3de963a..42f0d4c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/matrixalgebrakit.jl b/src/matrixalgebrakit.jl index e46215b..84eaa26 100644 --- a/src/matrixalgebrakit.jl +++ b/src/matrixalgebrakit.jl @@ -209,36 +209,40 @@ struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy strategy::T end -## using FillArrays: OnesVector -## const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} -## const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -## const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +using FillArrays: OnesVector +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} axis(a) = only(axes(a)) -## # Convert indices determined with a generic call to `findtruncated` to indices -## # more suited for a KroneckerVector. -## function to_truncated_indices(values::OnesKroneckerVector, I) -## prods = cartesianproduct(axis(values))[I] -## I_id = only(to_indices(arg1(values), (:,))) -## I_data = unique(arg2.(prods)) -## # Drop truncations that occur within the identity. -## I_data = filter(I_data) do i -## return count(x -> arg2(x) == i, prods) == length(arg2(values)) -## end -## return I_id × I_data -## end -## function to_truncated_indices(values::KroneckerOnesVector, I) -## #I = findtruncated(Vector(values), strategy.strategy) -## prods = cartesianproduct(axis(values))[I] -## I_data = unique(arg1.(prods)) -## # Drop truncations that occur within the identity. -## I_data = filter(I_data) do i -## return count(x -> arg1(x) == i, prods) == length(arg2(values)) -## end -## I_id = only(to_indices(arg2(values), (:,))) -## return I_data × I_id -## end +# Convert indices determined with a generic call to `findtruncated` to indices +# more suited for a KroneckerVector. +function to_truncated_indices(values::OnesKroneckerVector, I) + prods = cartesianproduct(axis(values))[I] + I_id = only(to_indices(arg1(values), (:,))) + I_data = unique(arg2.(prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> arg2(x) == i, prods) == length(arg2(values)) + end + return I_id × I_data +end +function to_truncated_indices(values::KroneckerOnesVector, I) + #I = findtruncated(Vector(values), strategy.strategy) + prods = cartesianproduct(axis(values))[I] + I_data = unique(arg1.(prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> arg1(x) == i, prods) == length(arg2(values)) + end + I_id = only(to_indices(arg2(values), (:,))) + return I_data × I_id +end +# Fix ambiguity error. +function to_truncated_indices(values::OnesVectorOnesVector, I) + return throw(ArgumentError("Not implemented")) +end function to_truncated_indices(values::KroneckerVector, I) return throw(ArgumentError("Not implemented")) end diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index 57ad2e9..69be9f4 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -395,24 +395,22 @@ end @test b[Block(1)] == a[Block(1, 2)] @test b[Block(2)] == a[Block(2, 2)] - ## TODO: Broken, fix and re-enable. - @test_broken false - ## # svd_trunc - ## dev = adapt(arrayt) - ## r = @constinferred blockrange([2 × 2, 3 × 3]) - ## rng = StableRNG(1234) - ## d = Dict( - ## Block(1, 1) => δ(elt, (2, 2)) ⊗ randn(rng, elt, 2, 2), - ## Block(2, 2) => δ(elt, (3, 3)) ⊗ randn(rng, elt, 3, 3), - ## ) - ## a = @constinferred dev(blocksparse(d, (r, r))) - ## if arrayt === Array - ## u, s, v = svd_trunc(a; trunc=(; maxrank=6)) - ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5)) - ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ - ## else - ## @test_broken svd_trunc(a; trunc=(; maxrank=6)) - ## end + # svd_trunc + dev = adapt(arrayt) + r = @constinferred blockrange([2 × 2, 3 × 3]) + rng = StableRNG(1234) + d = Dict( + Block(1, 1) => δ(elt, (2, 2)) ⊗ randn(rng, elt, 2, 2), + Block(2, 2) => δ(elt, (3, 3)) ⊗ randn(rng, elt, 3, 3), + ) + a = @constinferred dev(blocksparse(d, (r, r))) + if arrayt === Array + u, s, v = svd_trunc(a; trunc=(; maxrank=6)) + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + else + @test_broken svd_trunc(a; trunc=(; maxrank=6)) + end @testset "Block deficient" begin da = Dict(Block(1, 1) => δ(elt, (2, 2)) ⊗ dev(randn(elt, 2, 2))) diff --git a/test/test_matrixalgebrakit_delta.jl b/test/test_matrixalgebrakit_delta.jl index 6f43e4f..f4a8a61 100644 --- a/test/test_matrixalgebrakit_delta.jl +++ b/test/test_matrixalgebrakit_delta.jl @@ -73,27 +73,26 @@ herm(a) = parent(hermitianpart(a)) @test arguments(v, 2) isa DeltaMatrix{elt} end - ## TODO: Broken, need to fix truncation. - ## for f in (eig_trunc, eigh_trunc) - ## a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) - ## d, v = f(a; trunc=(; maxrank=7)) - ## @test a * v ≈ v * d - ## @test arguments(d, 1) isa DeltaMatrix - ## @test arguments(v, 1) isa DeltaMatrix - ## @test size(d) == (6, 6) - ## @test size(v) == (9, 6) + for f in (eig_trunc, eigh_trunc) + a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 1) isa DeltaMatrix + @test arguments(v, 1) isa DeltaMatrix + @test size(d) == (6, 6) + @test size(v) == (9, 6) - ## a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) - ## d, v = f(a; trunc=(; maxrank=7)) - ## @test a * v ≈ v * d - ## @test arguments(d, 2) isa DeltaMatrix - ## @test arguments(v, 2) isa DeltaMatrix - ## @test size(d) == (6, 6) - ## @test size(v) == (9, 6) + a = parent(hermitianpart(randn(3, 3))) ⊗ δ(3, 3) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 2) isa DeltaMatrix + @test arguments(v, 2) isa DeltaMatrix + @test size(d) == (6, 6) + @test size(v) == (9, 6) - ## a = δ(3, 3) ⊗ δ(3, 3) - ## @test_throws ArgumentError f(a) - ## end + a = δ(3, 3) ⊗ δ(3, 3) + @test_throws ArgumentError f(a) + end for f in (eig_vals, eigh_vals) a = δ(3, 3) ⊗ parent(hermitianpart(randn(3, 3))) @@ -183,46 +182,44 @@ herm(a) = parent(hermitianpart(a)) end end - ## TODO: Need to implement truncation. - ## # svd_trunc - ## for elt in (Float32, ComplexF32) - ## a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) - ## # TODO: Type inference is broken for `svd_trunc`, - ## # look into fixing it. - ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - ## @test eltype(u) === elt - ## @test eltype(s) === real(elt) - ## @test eltype(v) === elt - ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ - ## @test arguments(u, 1) isa DeltaMatrix{elt} - ## @test arguments(s, 1) isa DeltaMatrix{real(elt)} - ## @test arguments(v, 1) isa DeltaMatrix{elt} - ## @test size(u) == (9, 6) - ## @test size(s) == (6, 6) - ## @test size(v) == (6, 9) - ## end + # svd_trunc + for elt in (Float32, ComplexF32) + a = δ(elt, 3, 3) ⊗ randn(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa DeltaMatrix{elt} + @test arguments(s, 1) isa DeltaMatrix{real(elt)} + @test arguments(v, 1) isa DeltaMatrix{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - ## TODO: Need to implement truncation. - ## for elt in (Float32, ComplexF32) - ## a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) - ## # TODO: Type inference is broken for `svd_trunc`, - ## # look into fixing it. - ## # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) - ## u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - ## @test eltype(u) === elt - ## @test eltype(s) === real(elt) - ## @test eltype(v) === elt - ## u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - ## @test Matrix(u * s * v) ≈ u′ * s′ * v′ - ## @test arguments(u, 2) isa DeltaMatrix{elt} - ## @test arguments(s, 2) isa DeltaMatrix{real(elt)} - ## @test arguments(v, 2) isa DeltaMatrix{elt} - ## @test size(u) == (9, 6) - ## @test size(s) == (6, 6) - ## @test size(v) == (6, 9) - ## end + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ δ(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa DeltaMatrix{elt} + @test arguments(s, 2) isa DeltaMatrix{real(elt)} + @test arguments(v, 2) isa DeltaMatrix{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end a = δ(3, 3) ⊗ δ(3, 3) @test_broken svd_trunc(a)