diff --git a/Project.toml b/Project.toml index 2921538..5c432d6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.18" +version = "0.1.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl index 91f579f..1682ad5 100644 --- a/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl +++ b/ext/KroneckerArraysBlockSparseArraysExt/KroneckerArraysBlockSparseArraysExt.jl @@ -1,9 +1,53 @@ module KroneckerArraysBlockSparseArraysExt -using BlockSparseArrays: BlockSparseArrays, blockrange -using KroneckerArrays: CartesianProduct, cartesianrange +using BlockArrays: Block +using BlockSparseArrays: BlockIndexVector, GenericBlockIndex +using KroneckerArrays: CartesianPair, CartesianProduct +function Base.getindex(b::Block, I1::CartesianPair, Irest::CartesianPair...) + return GenericBlockIndex(b, (I1, Irest...)) +end +function Base.getindex(b::Block, I1::CartesianProduct, Irest::CartesianProduct...) + return BlockIndexVector(b, (I1, Irest...)) +end + +using BlockSparseArrays: BlockSparseArrays, BlockUnitRange, blockrange +using KroneckerArrays: CartesianPair, CartesianProduct, ×, cartesianrange + function BlockSparseArrays.blockrange(bs::Vector{<:CartesianProduct}) - return blockrange(map(cartesianrange, bs)) + return blockrange(cartesianrange.(bs)) +end +function BlockSparseArrays.blockrange(bs::Vector{<:CartesianPair{<:Integer,<:Integer}}) + bs′ = map(bs) do b + return Base.OneTo(arg1(b)) × Base.OneTo(arg2(b)) + end + return blockrange(bs′) +end + +using BlockSparseArrays: BlockSparseArrays, infimum +using KroneckerArrays: cartesianproduct, CartesianProductUnitRange +function BlockSparseArrays.infimum(r1::CartesianProductUnitRange, r2::CartesianProductUnitRange) + return cartesianrange(infimum(cartesianproduct.((r1, r2))...)) +end +function BlockSparseArrays.infimum(r1::CartesianProduct, r2::CartesianProduct) + return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2)) +end + +using BlockArrays: Block +using KroneckerArrays: cartesianrange +function Base.getindex( + r::BlockUnitRange{<:Integer,<:Vector{<:CartesianProduct}}, I::Block{1,Int64} +) + prod = eachblockaxis(r)[Int(I)] + range = r.r[I] + return cartesianrange(prod, range) +end + +# Fix ambiguity error with BlockArrays.jl. +using BlockArrays: AbstractBlockArray +function Base.similar( + a::AbstractBlockArray, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} +) + return similar(a, eltype(a), axs) end using BlockArrays: AbstractBlockedUnitRange diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index c77a2f5..6ddd6b9 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -1,4 +1,22 @@ -struct CartesianProduct{A,B} +struct CartesianPair{A,B} + a::A + b::B +end +arguments(a::CartesianPair) = (a.a, a.b) +arguments(a::CartesianPair, n::Int) = arguments(a)[n] + +arg1(a::CartesianPair) = a.a +arg2(a::CartesianPair) = a.b + +×(a, b) = CartesianPair(a, b) + +function Base.show(io::IO, a::CartesianPair) + print(io, a.a, " × ", a.b) + return nothing +end + +struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <: + AbstractVector{CartesianPair{TA,TB}} a::A b::B end @@ -13,17 +31,44 @@ function Base.show(io::IO, a::CartesianProduct) return nothing end -×(a, b) = CartesianProduct(a, b) -Base.length(a::CartesianProduct) = length(a.a) * length(a.b) -Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b] +# This is used when printing block sparse arrays with KroneckerArray +# blocks. +# TODO: Investigate if this is needed or if it can be avoided +# by iterating over CartesianProduct axes. +function Base.checkindex(::Type{Bool}, inds::CartesianProduct, i::Int) + return checkindex(Bool, Base.OneTo(length(inds)), i) +end + +×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b) +Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a)) +Base.size(a::CartesianProduct) = (length(a),) +function Base.getindex(a::CartesianProduct, i::CartesianProduct) + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] +end +function Base.getindex(a::CartesianProduct, i::CartesianPair) + return arg1(a)[arg1(i)] × arg2(a)[arg2(i)] +end +function Base.getindex(a::CartesianProduct, i::Int) + I = Tuple(CartesianIndices((length(arg1(a)), length(arg2(a))))[i]) + return a[I[1] × I[2]] +end + +using Base: promote_shape +function Base.promote_shape( + a::Tuple{Vararg{CartesianProduct}}, b::Tuple{Vararg{CartesianProduct}} +) + return promote_shape(arg1.(a), arg1.(b)) × promote_shape(arg2.(a), arg2.(b)) +end -function Base.iterate(a::CartesianProduct, state...) - x = iterate(Iterators.product(a.a, a.b), state...) - isnothing(x) && return x - next, new_state = x - return ×(next...), new_state +using Base.Broadcast: axistype +function Base.Broadcast.axistype(r1::CartesianProduct, r2::CartesianProduct) + return axistype(arg1(r1), arg1(r2)) × axistype(arg2(r1), arg2(r2)) end +## function Base.to_index(A::KroneckerArray, I::CartesianProduct) +## return I +## end + struct CartesianProductUnitRange{T,P<:CartesianProduct,R<:AbstractUnitRange{T}} <: AbstractUnitRange{T} product::P @@ -38,27 +83,36 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range) arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a)) arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a)) +function Base.show(io::IO, r::CartesianProductUnitRange) + print(io, cartesianproduct(r), ": ", unproduct(r)) + return nothing +end +function Base.show(io::IO, mime::MIME"text/plain", r::CartesianProductUnitRange) + show(io, mime, cartesianproduct(r)) + println(io) + show(io, mime, unproduct(r)) + return nothing +end + function CartesianProductUnitRange(p::CartesianProduct) return CartesianProductUnitRange(p, Base.OneTo(length(p))) end function CartesianProductUnitRange(a, b) return CartesianProductUnitRange(a × b) end -to_range(a::AbstractUnitRange) = a -to_range(i::Integer) = Base.OneTo(i) -cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) +to_product_indices(a::AbstractVector) = a +to_product_indices(i::Integer) = Base.OneTo(i) +cartesianrange(a, b) = cartesianrange(to_product_indices(a) × to_product_indices(b)) function cartesianrange(p::CartesianProduct) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return cartesianrange(p′, Base.OneTo(length(p′))) end function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) - p′ = to_range(p.a) × to_range(p.b) + p′ = to_product_indices(arg1(p)) × to_product_indices(arg2(p)) return CartesianProductUnitRange(p′, range) end -function Base.axes(r::CartesianProductUnitRange) - return (CartesianProductUnitRange(r.product, only(axes(r.range))),) -end +Base.axes(r::CartesianProductUnitRange) = (cartesianrange(cartesianproduct(r)),) using Base.Broadcast: DefaultArrayStyle for f in (:+, :-) @@ -84,3 +138,7 @@ function Base.Broadcast.axistype( range = axistype(unproduct(r1), unproduct(r2)) return cartesianrange(prod, range) end + +function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) + return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) +end diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index f132dbf..ef3ec4d 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -1,13 +1,17 @@ using FillArrays: FillArrays, Zeros function FillArrays.fillsimilar( - a::Zeros{T}, - ax::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, + a::Zeros{T}, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} ) where {T} return Zeros{T}(arg1.(ax)) ⊗ Zeros{T}(arg2.(ax)) end +# Work around that `Zeros` requires `AbstractUnitRange` axes. +function FillArrays.Zeros{T,N}( + ax::Tuple{CartesianProduct,Vararg{CartesianProduct}} +) where {T,N} + return Zeros{T,N}(cartesianslice.(ax)) +end + using FillArrays: RectDiagonal, OnesVector const RectEye{T,V<:OnesVector{T},Axes} = RectDiagonal{T,V,Axes} @@ -68,6 +72,8 @@ end # Like `copy` but preserves `Eye`. _copy(a::Eye) = a +_getindex(a::Eye, I1::Colon, I2::Colon) = a + using DerivableInterfaces: DerivableInterfaces, zero! function DerivableInterfaces.zero!(a::EyeKronecker) zero!(a.b) diff --git a/src/fillarrays/matrixalgebrakit.jl b/src/fillarrays/matrixalgebrakit.jl index 093760b..8e2def2 100644 --- a/src/fillarrays/matrixalgebrakit.jl +++ b/src/fillarrays/matrixalgebrakit.jl @@ -1,4 +1,4 @@ -function infimum(r1::AbstractRange, r2::AbstractUnitRange) +function infimum(r1::AbstractUnitRange, r2::AbstractUnitRange) Base.require_one_based_indexing(r1, r2) if length(r1) ≤ length(r2) return r1 @@ -6,7 +6,10 @@ function infimum(r1::AbstractRange, r2::AbstractUnitRange) return r2 end end -function supremum(r1::AbstractRange, r2::AbstractUnitRange) +function infimum(r1::CartesianProduct, r2::CartesianProduct) + return infimum(arg1(r1), arg1(r2)) × infimum(arg2(r1), arg2(r2)) +end +function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange) Base.require_one_based_indexing(r1, r2) if length(r1) ≥ length(r2) return r1 @@ -14,6 +17,9 @@ function supremum(r1::AbstractRange, r2::AbstractUnitRange) return r2 end end +function supremum(r1::CartesianProduct, r2::CartesianProduct) + return supremum(arg1(r1), arg1(r2)) × supremum(arg2(r1), arg2(r2)) +end # Allow customization for `Eye`. _diagview(a::Eye) = parent(a) diff --git a/src/fillarrays/matrixalgebrakit_truncate.jl b/src/fillarrays/matrixalgebrakit_truncate.jl index ae50f27..22a19b7 100644 --- a/src/fillarrays/matrixalgebrakit_truncate.jl +++ b/src/fillarrays/matrixalgebrakit_truncate.jl @@ -24,11 +24,11 @@ function MatrixAlgebraKit.findtruncated( values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.a, prods)) + prods = only(axes(values))[I] + I_data = unique(arg1.(prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> x.a == i, prods) == length(values.a) + return count(x -> arg1(x) == i, prods) == length(arg1(values)) end return (:) × I_data end @@ -36,11 +36,11 @@ function MatrixAlgebraKit.findtruncated( values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy ) I = findtruncated(Vector(values), strategy.strategy) - prods = collect(only(axes(values)).product)[I] - I_data = unique(map(x -> x.b, prods)) + prods = only(axes(values))[I] + I_data = unique(map(x -> arg2(x), prods)) # Drop truncations that occur within the identity. I_data = filter(I_data) do i - return count(x -> x.b == i, prods) == length(values.b) + return count(x -> arg2(x) == i, prods) == length(arg2(values)) end return I_data × (:) end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 82b4357..776e968 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -38,6 +38,11 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) return dest end +using Base: has_offset_axes +function Base.has_offset_axes(a::KroneckerArray) + return has_offset_axes(arg1(a)) || has_offset_axes(arg2(a)) +end + # Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`. function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) return similar(a, elt, axs) @@ -46,43 +51,34 @@ function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitR return similar(arrayt, axs) end +function Base.similar(a::KroneckerArray, elt::Type) + return similar(a, elt, axes(a)) +end +function Base.similar( + a::AbstractArray, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} +) + return similar(a, eltype(a), axs) +end function Base.similar( - a::AbstractArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, + a::AbstractArray, elt::Type, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} ) - return _similar(a, elt, map(ax -> ax.product.a, axs)) ⊗ - _similar(a, elt, map(ax -> ax.product.b, axs)) + return _similar(a, elt, arg1.(axs)) ⊗ _similar(a, elt, arg2.(axs)) end function Base.similar( - a::KroneckerArray, - elt::Type, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, + a::KroneckerArray, elt::Type, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} ) - return _similar(a.a, elt, map(ax -> ax.product.a, axs)) ⊗ - _similar(a.b, elt, map(ax -> ax.product.b, axs)) + return _similar(arg1(a), elt, arg1.(axs)) ⊗ _similar(arg2(a), elt, arg2.(axs)) end function Base.similar( - arrayt::Type{<:AbstractArray}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, + arrayt::Type{<:AbstractArray}, axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}} ) - return _similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ - _similar(arrayt, map(ax -> ax.product.b, axs)) + return _similar(arrayt, arg1.(axs)) ⊗ _similar(arrayt, arg2.(axs)) end function Base.similar( arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, - axs::Tuple{ - CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} - }, + axs::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}, ) where {A,B} - return _similar(A, map(ax -> ax.product.a, axs)) ⊗ - _similar(B, map(ax -> ax.product.b, axs)) + return _similar(A, arg1.(axs)) ⊗ _similar(B, arg2.(axs)) end function Base.similar( ::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}} @@ -123,13 +119,13 @@ function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) end -Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) +function Base.size(a::KroneckerArray) + ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a)) +end function Base.axes(a::KroneckerArray) return ntuple(ndims(a)) do dim - return CartesianProductUnitRange( - axes(a.a, dim) × axes(a.b, dim), Base.OneTo(size(a, dim)) - ) + return cartesianrange(axes(arg1(a), dim) × axes(arg2(a), dim)) end end @@ -160,33 +156,22 @@ function Base.getindex(a::KroneckerArray, i::Integer) return a[CartesianIndices(a)[i]] end -# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing -# in the n-dimensional case and use it to replace the matrix and vector cases: -# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 -function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} - return error("Not implemented.") -end - using GPUArraysCore: GPUArraysCore -function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} GPUArraysCore.assertscalar("getindex") - # Code logic from Kronecker.jl: - # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 - k, l = size(a.b) - return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] + I′ = ntuple(Val(N)) do dim + return axes(a, dim)[I[dim]] + end + return a[I′...] end -function Base.getindex(a::KroneckerVector, i::Integer) - GPUArraysCore.assertscalar("getindex") - k = length(a.b) - return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] +# Allow customizing for `FillArrays.Eye`. +_getindex(a::AbstractArray, I...) = a[I...] +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} + return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) end - -## function Base.getindex(a::KroneckerVector, i::CartesianProduct) -## return a.a[i.a] ⊗ a.b[i.b] -## end function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} - return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] + return _getindex(arg1(a), arg1.(I)...) ⊗ _getindex(arg2(a), arg2.(I)...) end # Fix ambigiuity error. Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 5727e26..683b58f 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,7 +1,9 @@ +using BlockSparseArrays: BlockSparseArrays using KroneckerArrays: KroneckerArrays using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(KroneckerArrays) + # TODO: Investigate and fix ambiguities. + Aqua.test_all(KroneckerArrays; ambiguities=false) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 5684936..86c35c7 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -7,7 +7,6 @@ using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle, - CartesianProductUnitRange, ⊗, ×, cartesianproduct, @@ -25,19 +24,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test length(p) == 6 @test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5] - r = @constinferred cartesianrange(2, 3) - @test r === - @constinferred(cartesianrange(2 × 3)) === - @constinferred(cartesianrange(Base.OneTo(2), Base.OneTo(3))) === - @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) + r = @constinferred cartesianrange(Base.OneTo(2), Base.OneTo(3)) + @test r === @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3))) @test @constinferred(cartesianproduct(r)) === Base.OneTo(2) × Base.OneTo(3) @test unproduct(r) === Base.OneTo(6) @test length(r) == 6 @test first(r) == 1 @test last(r) == 6 - r = @constinferred(cartesianrange(2 × 3, 2:7)) - @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) + r = @constinferred(cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)) @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) @test unproduct(r) === 2:7 @test length(r) == 6