diff --git a/Project.toml b/Project.toml index 38353b5..c091764 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 0bce6f8..92ca131 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -35,6 +35,27 @@ end Base.first(r::CartesianProductUnitRange) = first(r.range) Base.last(r::CartesianProductUnitRange) = last(r.range) +cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product) +unproduct(r::CartesianProductUnitRange) = getfield(r, :range) + +function CartesianProductUnitRange(p::CartesianProduct) + return CartesianProductUnitRange(p, Base.OneTo(length(p))) +end +function CartesianProductUnitRange(a, b) + return CartesianProductUnitRange(a × b) +end +to_range(a::AbstractUnitRange) = a +to_range(i::Integer) = Base.OneTo(i) +cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b)) +function cartesianrange(p::CartesianProduct) + p′ = to_range(p.a) × to_range(p.b) + return cartesianrange(p′, Base.OneTo(length(p′))) +end +function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) + p′ = to_range(p.a) × to_range(p.b) + return CartesianProductUnitRange(p′, range) +end + function Base.axes(r::CartesianProductUnitRange) return (CartesianProductUnitRange(r.product, only(axes(r.range))),) end diff --git a/test/test_basics.jl b/test/test_basics.jl index c2aaedc..76ae661 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,5 +1,14 @@ using FillArrays: Eye -using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd +using KroneckerArrays: + KroneckerArrays, + CartesianProductUnitRange, + ⊗, + ×, + cartesianproduct, + cartesianrange, + diagonal, + kron_nd, + unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -10,6 +19,25 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) @test length(p) == 6 @test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5] + r = cartesianrange(2, 3) + @test r === + cartesianrange(2 × 3) === + cartesianrange(Base.OneTo(2), Base.OneTo(3)) === + cartesianrange(Base.OneTo(2) × Base.OneTo(3)) + @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) + @test unproduct(r) === Base.OneTo(6) + @test length(r) == 6 + @test first(r) == 1 + @test last(r) == 6 + + r = cartesianrange(2 × 3, 2:7) + @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) + @test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3) + @test unproduct(r) === 2:7 + @test length(r) == 6 + @test first(r) == 2 + @test last(r) == 7 + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c = a.a ⊗ b.b