From b3b5ce636d28d54366eb01f87fdd1c315a009ea7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 22 Jun 2025 20:46:37 -0400 Subject: [PATCH] Index CartesionProductUnitRange with CartesianPair --- Project.toml | 2 +- src/cartesianproduct.jl | 6 ++++++ test/test_basics.jl | 12 ++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 43ecb60..1177da1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.19" +version = "0.1.20" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index 88806da..e3eb2fe 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -94,6 +94,12 @@ function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::Carte return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end +# Reverse map from CartesianPair to linear index in the range. +function Base.getindex(inds::CartesianProductUnitRange, i::CartesianPair) + i′ = (findfirst(==(arg1(i)), arg1(inds)), findfirst(==(arg2(i)), arg2(inds))) + return inds[LinearIndices((length(arg1(inds)), length(arg2(inds))))[i′...]] +end + using Base.Broadcast: DefaultArrayStyle for f in (:+, :-) @eval begin diff --git a/test/test_basics.jl b/test/test_basics.jl index db58023..c8fa835 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -36,6 +36,12 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test length(r) == 6 @test first(r) == 1 @test last(r) == 6 + @test r[1 × 1] == 1 + @test r[2 × 1] == 2 + @test r[1 × 2] == 3 + @test r[2 × 2] == 4 + @test r[1 × 3] == 5 + @test r[2 × 3] == 6 r = @constinferred(cartesianrange(2 × 3, 2:7)) @test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7) @@ -44,6 +50,12 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test length(r) == 6 @test first(r) == 2 @test last(r) == 7 + @test r[1 × 1] == 2 + @test r[2 × 1] == 3 + @test r[1 × 2] == 4 + @test r[2 × 2] == 5 + @test r[1 × 3] == 6 + @test r[2 × 3] == 7 # Test high-dimensional materialization. a = randn(elt, 2, 2, 2) ⊗ randn(elt, 2, 2, 2)