diff --git a/Project.toml b/Project.toml index d1c03f7..dbbe6bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.21" +version = "0.1.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -23,7 +23,7 @@ KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] [compat] Adapt = "4.3.0" BlockArrays = "1.6" -BlockSparseArrays = "0.7.21" +BlockSparseArrays = "0.7.22" DerivableInterfaces = "0.5.0" DiagonalArrays = "0.3.5" FillArrays = "1.13.0" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 9fa1018..b23c869 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -99,4 +99,10 @@ function (f::GetUnstoredBlock)( return error("Not implemented.") end +using BlockSparseArrays: BlockSparseArrays +using KroneckerArrays: KroneckerArrays, KroneckerVector +function BlockSparseArrays.to_truncated_indices(values::KroneckerVector, I) + return KroneckerArrays.to_truncated_indices(values, I) +end + end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index e3eb2fe..e0138f3 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -26,10 +26,16 @@ arguments(a::CartesianProduct, n::Int) = arguments(a)[n] arg1(a::CartesianProduct) = a.a arg2(a::CartesianProduct) = a.b +Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a)) + function Base.show(io::IO, a::CartesianProduct) print(io, a.a, " × ", a.b) return nothing end +function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct) + show(io, a) + return nothing +end ×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b) Base.length(a::CartesianProduct) = length(a.a) * length(a.b) @@ -42,8 +48,38 @@ function Base.getindex(a::CartesianProduct, i::CartesianPair) return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] end function Base.getindex(a::CartesianProduct, i::Int) - I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i]) - return a[I[1] × I[2]] + I = Tuple(CartesianIndices((length(arg2(a)), length(arg1(a))))[i]) + return a[I[2] × I[1]] +end + +struct CartesianProductVector{T,P<:CartesianProduct,V<:AbstractVector{T}} <: + AbstractVector{T} + product::P + values::V +end +cartesianproduct(r::CartesianProductVector) = getfield(r, :product) +unproduct(r::CartesianProductVector) = getfield(r, :values) +Base.length(a::CartesianProductVector) = length(unproduct(a)) +Base.size(a::CartesianProductVector) = (length(a),) +function Base.axes(r::CartesianProductVector) + return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) +end +function Base.copy(a::CartesianProductVector) + return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) +end +function Base.getindex(r::CartesianProductVector, i::Integer) + return unproduct(r)[i] +end + +function Base.show(io::IO, a::CartesianProductVector) + show(io, unproduct(a)) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductVector) + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing end struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: @@ -60,13 +96,24 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range) arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +function Base.show(io::IO, a::CartesianProductUnitRange) + show(io, unproduct(a)) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", a::CartesianProductUnitRange) + show(io, mime, cartesianproduct(a)) + println(io) + show(io, mime, unproduct(a)) + return nothing +end + function CartesianProductUnitRange(p::CartesianProduct) return CartesianProductUnitRange(p, Base.OneTo(length(p))) end function CartesianProductUnitRange(a, b) return CartesianProductUnitRange(a × b) end -to_product_indices(a::AbstractUnitRange) = a +to_product_indices(a::AbstractVector) = a to_product_indices(i::Integer) = Base.OneTo(i) cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b)) function cartesianrange(p::CartesianPair) @@ -94,10 +141,16 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end +function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) + prod = cartesianproduct(a) + prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] + return CartesianProductVector(prod_I, map(Base.Fix1(getindex, a), I)) +end + # Reverse map from CartesianPair to linear index in the range. function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) - i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds))) - return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]] + i′ = (findfirst(==(arg2(i)), arg2(inds)), findfirst(==(arg1(i)), arg1(inds))) + return inds[LinearIndices((length(arg2(inds)), length(arg1(inds))))[i′...]] end using Base.Broadcast: DefaultArrayStyle diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index fc40ef7..f943ce0 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -22,7 +22,13 @@ const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatr const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} _getindex(a::Eye, I1::Colon, I2::Colon) = a +_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a +_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a +_getindex(a::Eye, I1::Colon, I2::Base.Slice) = a _view(a::Eye, I1::Colon, I2::Colon) = a +_view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a +_view(a::Eye, I1::Base.Slice, I2::Colon) = a +_view(a::Eye, I1::Colon, I2::Base.Slice) = a # Like `adapt` but preserves `Eye`. _adapt(to, a::Eye) = a diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl index e505cf2..1bf6ed2 100644 --- a/src/fillarrays/matrixalgebrakit_truncate.jl +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -20,34 +20,40 @@ const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVe const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} -function MatrixAlgebraKit.findtruncated( - values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(cartesianproduct(only(axes(values))))[I] - I_data = unique(map(arg1, prods)) +axis(a) = only(axes(a)) + +# Convert indices determined with a generic call to `findtruncated` to indices +# more suited for a KroneckerVector. +function to_truncated_indices(values::OnesKroneckerVector, I) + prods = cartesianproduct(axis(values))[I] + I_id = only(to_indices(arg1(values), (:,))) + I_data = unique(arg2.(prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg1(x) == i, prods) == length(arg1(values)) + return count(x -> arg2(x) == i, prods) == length(arg2(values)) end - return (:) × I_data + return I_id × I_data end -function MatrixAlgebraKit.findtruncated( - values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy -) - I = findtruncated(Vector(values), strategy.strategy) - prods = collect(cartesianproduct(only(axes(values))))[I] - I_data = unique(map(x -> arg2(x), prods)) +function to_truncated_indices(values::KroneckerOnesVector, I) + #I = findtruncated(Vector(values), strategy.strategy) + prods = cartesianproduct(axis(values))[I] + I_data = unique(arg1.(prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> arg2(x) == i, prods) == length(arg2(values)) + return count(x -> arg1(x) == i, prods) == length(arg2(values)) end - return I_data × (:) + I_id = only(to_indices(arg2(values), (:,))) + return I_data × I_id +end +function to_truncated_indices(values::OnesVectorOnesVector, I) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) end + function MatrixAlgebraKit.findtruncated( - values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy + values::KroneckerVector, strategy::KroneckerTruncationStrategy ) - return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) + I = findtruncated(Vector(values), strategy.strategy) + return to_truncated_indices(values, I) end for f in [:eig_trunc!, :eigh_trunc!] diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 94d601a..96167f3 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -167,13 +167,22 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where { return a[I′...] end +# Indexing logic. +function Base.to_indices( + a::KroneckerArray, inds, I::Tuple{Union{CartesianPair,CartesianProduct},Vararg} +) + I1 = to_indices(arg1(a), arg1.(inds), arg1.(I)) + I2 = to_indices(arg2(a), arg2.(inds), arg2.(I)) + return I1 .× I2 +end + # Allow customizing for `FillArrays.Eye`. _getindex(a::AbstractArray, I...) = a[I...] -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} - return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) -end -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} - return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) +function Base.getindex( + a::KroneckerArray{<:Any,N}, I::Vararg{Union{CartesianPair,CartesianProduct},N} +) where {N} + I′ = to_indices(a, I) + return _getindex(arg1(a), arg1.(I′)...) ⊗ _getindex(arg2(a), arg2.(I′)...) end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] diff --git a/test/test_basics.jl b/test/test_basics.jl index 0fe2ef3..ee1ec6c 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -26,7 +26,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "KroneckerArrays (eltype=$elt)" for elt in elts p = [1, 2] × [3, 4, 5] @test length(p) == 6 - @test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5] + @test collect(p) == [1 × 3, 1 × 4, 1 × 5, 2 × 3, 2 × 4, 2 × 5] r = @constinferred cartesianrange(2, 3) @test r === @@ -39,10 +39,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 1 @test last(r) == 6 @test r[1 × 1] == 1 - @test r[2 × 1] == 2 - @test r[1 × 2] == 3 - @test r[2 × 2] == 4 - @test r[1 × 3] == 5 + @test r[1 × 2] == 2 + @test r[1 × 3] == 3 + @test r[2 × 1] == 4 + @test r[2 × 2] == 5 @test r[2 × 3] == 6 r = @constinferred(cartesianrange(2 × 3, 2:7)) @@ -53,10 +53,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test first(r) == 2 @test last(r) == 7 @test r[1 × 1] == 2 - @test r[2 × 1] == 3 - @test r[1 × 2] == 4 - @test r[2 × 2] == 5 - @test r[1 × 3] == 6 + @test r[1 × 2] == 3 + @test r[1 × 3] == 4 + @test r[2 × 1] == 5 + @test r[2 × 2] == 6 @test r[2 × 3] == 7 # Test high-dimensional materialization. diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index bd88d9d..770e216 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -1,10 +1,10 @@ using Adapt: adapt using BlockArrays: Block, BlockRange, mortar using BlockSparseArrays: - BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype + BlockIndexVector, BlockSparseArray, BlockSparseMatrix, blockrange, blocksparse, blocktype using FillArrays: Eye, SquareEye using JLArrays: JLArray -using KroneckerArrays: KroneckerArray, ⊗, × +using KroneckerArrays: KroneckerArray, ⊗, ×, arg1, arg2 using LinearAlgebra: norm using MatrixAlgebraKit: svd_compact using Test: @test, @test_broken, @testset @@ -48,7 +48,18 @@ arrayts = (Array, JLArray) @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) + I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3]) + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test iszero(b[Block(2, 1)]) + @test iszero(b[Block(1, 2)]) + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] # Slicing r = blockrange([2 × 2, 3 × 3]) @@ -159,7 +170,22 @@ end @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] - @test_broken a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + @test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] == + a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)] + + # Blockwise slicing, shows up in truncated block sparse matrix factorizations. + I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1]) + I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3]) + I = [I1, I2] + b = a[I, I] + @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]] + @test arg1(b[Block(1, 1)]) isa Eye + @test iszero(b[Block(2, 1)]) + @test arg1(b[Block(2, 1)]) isa Eye + @test iszero(b[Block(1, 2)]) + @test arg1(b[Block(1, 2)]) isa Eye + @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]] + @test arg1(b[Block(2, 2)]) isa Eye # Slicing r = blockrange([2 × 2, 3 × 3])