diff --git a/Project.toml b/Project.toml index 2aa9b11..3a2f50b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.26" +version = "0.1.27" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -16,14 +16,16 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [weakdeps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" +TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"] +KroneckerArraysTensorProductsExt = "TensorProducts" [compat] Adapt = "4.3" BlockArrays = "1.6" -BlockSparseArrays = "0.8.1" +BlockSparseArrays = "0.9" DerivableInterfaces = "0.5" DiagonalArrays = "0.3.5" FillArrays = "1.13" @@ -31,4 +33,5 @@ GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.9" MatrixAlgebraKit = "0.2" +TensorProducts = "0.1.7" julia = "1.10" diff --git a/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl new file mode 100644 index 0000000..f45cc37 --- /dev/null +++ b/ext/KroneckerArraysTensorProductsExt/KroneckerArraysTensorProductsExt.jl @@ -0,0 +1,11 @@ +module KroneckerArraysTensorProductsExt + +using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct +using TensorProducts: TensorProducts, tensor_product +function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo) + prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2)) + range = tensor_product(unproduct(a1), unproduct(a2)) + return cartesianrange(prod, range) +end + +end diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index be1c4fa..06d54f3 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -62,7 +62,9 @@ unproduct(r::CartesianProductVector) = getfield(r, :values) Base.length(a::CartesianProductVector) = length(unproduct(a)) Base.size(a::CartesianProductVector) = (length(a),) function Base.axes(r::CartesianProductVector) - return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) + prod = cartesianproduct(r) + prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) + return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) end function Base.copy(a::CartesianProductVector) return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a))) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index f943ce0..2298039 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -41,24 +41,11 @@ function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T} RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a)) end -# Like `similar` but preserves `Eye`. -function _similar(a::AbstractArray, elt::Type, ax::Tuple) - return similar(a, elt, ax) +# Like `similar` but preserves `Eye`, `Ones`, etc. +using FillArrays: Ones +function _similar(arrayt::Type{<:Ones}, axs::Tuple) + return Ones{eltype(arrayt)}(axs) end -function _similar(A::Type{<:AbstractArray}, ax::Tuple) - return similar(A, ax) -end -function _similar(a::AbstractArray, ax::Tuple) - return _similar(a, eltype(a), ax) -end -function _similar(a::AbstractArray, elt::Type) - return _similar(a, elt, axes(a)) -end -function _similar(a::AbstractArray) - return _similar(a, eltype(a), axes(a)) -end - -# Like `similar` but preserves `Eye`. function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange}) return Eye{elt}(axs) end @@ -77,19 +64,6 @@ end # Like `copy` but preserves `Eye`. _copy(a::Eye) = a -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::EyeKronecker) - zero!(a.b) - return a -end -function DerivableInterfaces.zero!(a::KroneckerEye) - zero!(a.a) - return a -end -function DerivableInterfaces.zero!(a::EyeEye) - return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) -end - using Base.Broadcast: AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 0473b1a..6b492ef 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -54,10 +54,19 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where end # Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`. -function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) +function _similar(a::AbstractArray, elt::Type, axs::Tuple) return similar(a, elt, axs) end -function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitRange}}) +function _similar(a::AbstractArray, ax::Tuple) + return _similar(a, eltype(a), ax) +end +function _similar(a::AbstractArray, elt::Type) + return _similar(a, elt, axes(a)) +end +function _similar(a::AbstractArray) + return _similar(a, eltype(a), axes(a)) +end +function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple) return similar(arrayt, axs) end @@ -130,6 +139,16 @@ Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a))) Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a)) +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::KroneckerArray) + ismut1 = ismutable(arg1(a)) + ismut2 = ismutable(arg2(a)) + (ismut1 || ismut2) || throw(ArgumentError("Can't zero out immutable KroneckerArray.")) + ismut1 && zero!(arg1(a)) + ismut2 && zero!(arg2(a)) + return a +end + function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) end @@ -372,13 +391,15 @@ _eltype(x) = eltype(x) _eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...) using Base.Broadcast: broadcasted -struct KroneckerBroadcasted{A<:Broadcasted,B<:Broadcasted} +struct KroneckerBroadcasted{A,B} a::A b::B end arg1(a::KroneckerBroadcasted) = a.a arg2(a::KroneckerBroadcasted) = a.b ⊗(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b) +⊗(a::Broadcasted, b) = KroneckerBroadcasted(a, b) +⊗(a, b::Broadcasted) = KroneckerBroadcasted(a, b) Broadcast.materialize(a::KroneckerBroadcasted) = copy(a) Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) Broadcast.broadcastable(a::KroneckerBroadcasted) = a diff --git a/src/linearalgebra.jl b/src/linearalgebra.jl index 95bed45..aaf3e06 100644 --- a/src/linearalgebra.jl +++ b/src/linearalgebra.jl @@ -179,10 +179,3 @@ function LinearAlgebra.lq(a::KroneckerArray) Fb = lq(a.b) return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q) end - -using DerivableInterfaces: DerivableInterfaces, zero! -function DerivableInterfaces.zero!(a::KroneckerArray) - zero!(a.a) - zero!(a.b) - return a -end diff --git a/test/Project.toml b/test/Project.toml index 5fb06cf..f649d4b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -21,7 +22,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Adapt = "4" Aqua = "0.8" BlockArrays = "1.6" -BlockSparseArrays = "0.8.1" +BlockSparseArrays = "0.9" DerivableInterfaces = "0.5" DiagonalArrays = "0.3.7" FillArrays = "1" @@ -33,5 +34,6 @@ MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" StableRNGs = "1.0" Suppressor = "0.2" +TensorProducts = "0.1.7" Test = "1.10" TestExtras = "0.3" diff --git a/test/test_basics.jl b/test/test_basics.jl index ee1ec6c..e17f8c1 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -9,6 +9,7 @@ using KroneckerArrays: KroneckerArray, KroneckerStyle, CartesianProductUnitRange, + CartesianProductVector, ⊗, ×, arg1, @@ -45,6 +46,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test r[2 × 2] == 5 @test r[2 × 3] == 6 + # CartesianProductUnitRange axes + r = cartesianrange((2:3) × (3:4), 2:5) + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + + # CartesianProductVector axes + r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9]) + @test axes(r) ≡ (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),) + r = @constinferred(cartesianrange(2 × 3, 2:7)) @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) diff --git a/test/test_blocksparsearrays.jl b/test/test_blocksparsearrays.jl index e971edd..1dcf58a 100644 --- a/test/test_blocksparsearrays.jl +++ b/test/test_blocksparsearrays.jl @@ -23,7 +23,7 @@ arrayts = (Array, JLArray) Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) @test sprint(show, a) isa String @test sprint(show, MIME("text/plain"), a) isa String @test blocktype(a) === valtype(d) @@ -45,7 +45,7 @@ arrayts = (Array, JLArray) Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] @@ -68,7 +68,7 @@ arrayts = (Array, JLArray) Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) i1 = Block(1)[(1:2) × (1:2)] i2 = Block(2)[(2:3) × (2:3)] I = mortar([i1, i2]) @@ -83,7 +83,7 @@ arrayts = (Array, JLArray) Block(1, 1) => dev(randn(elt, 2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(randn(elt, 3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) i1 = Block(1)[(1:2) × (1:2)] i2 = Block(2)[(2:3) × (2:3)] I = [i1, i2] @@ -130,9 +130,12 @@ arrayts = (Array, JLArray) @test_broken svd_compact(a) end + b = a[Block.(1:2), Block(2)] + @test b[Block(1)] == a[Block(1, 2)] + @test b[Block(2)] == a[Block(2, 2)] + # Broken operations @test_broken exp(a) - @test_broken a[Block.(1:2), Block(2)] end @testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in @@ -145,7 +148,7 @@ end Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2)), Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3)), ) - a = @constinferred dev(blocksparse(d, r, r)) + a = @constinferred dev(blocksparse(d, (r, r))) @test sprint(show, a) == sprint(show, Array(a)) @test sprint(show, MIME("text/plain"), a) isa String @test @constinferred(blocktype(a)) === valtype(d) @@ -167,7 +170,7 @@ end Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) @test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] == a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)] @test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)] @@ -194,7 +197,7 @@ end Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) i1 = Block(1)[(1:2) × (1:2)] i2 = Block(2)[(2:3) × (2:3)] I = mortar([i1, i2]) @@ -209,7 +212,7 @@ end Block(1, 1) => dev(Eye{elt}(2, 2) ⊗ randn(elt, 2, 2)), Block(2, 2) => dev(Eye{elt}(3, 3) ⊗ randn(elt, 3, 3)), ) - a = dev(blocksparse(d, r, r)) + a = dev(blocksparse(d, (r, r))) i1 = Block(1)[(1:2) × (1:2)] i2 = Block(2)[(2:3) × (2:3)] I = [i1, i2] @@ -272,7 +275,9 @@ end end # Broken operations - @test_broken a[Block.(1:2), Block(2)] + b = a[Block.(1:2), Block(2)] + @test b[Block(1)] == a[Block(1, 2)] + @test b[Block(2)] == a[Block(2, 2)] # svd_trunc dev = adapt(arrayt) @@ -282,7 +287,7 @@ end Block(1, 1) => Eye{elt}(2, 2) ⊗ randn(rng, elt, 2, 2), Block(2, 2) => Eye{elt}(3, 3) ⊗ randn(rng, elt, 3, 3), ) - a = @constinferred dev(blocksparse(d, r, r)) + a = @constinferred dev(blocksparse(d, (r, r))) if arrayt === Array u, s, v = svd_trunc(a; trunc=(; maxrank=6)) u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5)) @@ -293,10 +298,10 @@ end @testset "Block deficient" begin da = Dict(Block(1, 1) => Eye{elt}(2, 2) ⊗ dev(randn(elt, 2, 2))) - a = @constinferred dev(blocksparse(da, r, r)) + a = @constinferred dev(blocksparse(da, (r, r))) db = Dict(Block(2, 2) => Eye{elt}(3, 3) ⊗ dev(randn(elt, 3, 3))) - b = @constinferred dev(blocksparse(db, r, r)) + b = @constinferred dev(blocksparse(db, (r, r))) @test Array(a + b) ≈ Array(a) + Array(b) @test Array(2a) ≈ 2Array(a) diff --git a/test/test_tensorproducts.jl b/test/test_tensorproducts.jl new file mode 100644 index 0000000..3fa3c79 --- /dev/null +++ b/test/test_tensorproducts.jl @@ -0,0 +1,13 @@ +using KroneckerArrays: ×, arg1, arg2, cartesianrange, unproduct +using TensorProducts: tensor_product +using Test: @test, @testset + +@testset "KroneckerArraysTensorProductsExt" begin + r1 = cartesianrange(2, 3) + r2 = cartesianrange(4, 5) + r = tensor_product(r1, r2) + @test r ≡ cartesianrange(8, 15) + @test arg1(r) ≡ Base.OneTo(8) + @test arg2(r) ≡ Base.OneTo(15) + @test unproduct(r) ≡ Base.OneTo(120) +end