Skip to content

Commit b9010bb

Browse files
authored
More convenient cartesianrange constructor (#10)
1 parent 255ab88 commit b9010bb

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ end
3535
Base.first(r::CartesianProductUnitRange) = first(r.range)
3636
Base.last(r::CartesianProductUnitRange) = last(r.range)
3737

38+
cartesianproduct(r::CartesianProductUnitRange) = getfield(r, :product)
39+
unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
40+
41+
function CartesianProductUnitRange(p::CartesianProduct)
42+
return CartesianProductUnitRange(p, Base.OneTo(length(p)))
43+
end
44+
function CartesianProductUnitRange(a, b)
45+
return CartesianProductUnitRange(a × b)
46+
end
47+
to_range(a::AbstractUnitRange) = a
48+
to_range(i::Integer) = Base.OneTo(i)
49+
cartesianrange(a, b) = cartesianrange(to_range(a) × to_range(b))
50+
function cartesianrange(p::CartesianProduct)
51+
p′ = to_range(p.a) × to_range(p.b)
52+
return cartesianrange(p′, Base.OneTo(length(p′)))
53+
end
54+
function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
55+
p′ = to_range(p.a) × to_range(p.b)
56+
return CartesianProductUnitRange(p′, range)
57+
end
58+
3859
function Base.axes(r::CartesianProductUnitRange)
3960
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
4061
end

test/test_basics.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
using FillArrays: Eye
2-
using KroneckerArrays: KroneckerArrays, , ×, diagonal, kron_nd
2+
using KroneckerArrays:
3+
KroneckerArrays,
4+
CartesianProductUnitRange,
5+
,
6+
×,
7+
cartesianproduct,
8+
cartesianrange,
9+
diagonal,
10+
kron_nd,
11+
unproduct
312
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr
413
using StableRNGs: StableRNG
514
using Test: @test, @test_broken, @test_throws, @testset
@@ -10,6 +19,25 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
1019
@test length(p) == 6
1120
@test collect(p) == [1 × 3, 2 × 3, 1 × 4, 2 × 4, 1 × 5, 2 × 5]
1221

22+
r = cartesianrange(2, 3)
23+
@test r ===
24+
cartesianrange(2 × 3) ===
25+
cartesianrange(Base.OneTo(2), Base.OneTo(3)) ===
26+
cartesianrange(Base.OneTo(2) × Base.OneTo(3))
27+
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)
28+
@test unproduct(r) === Base.OneTo(6)
29+
@test length(r) == 6
30+
@test first(r) == 1
31+
@test last(r) == 6
32+
33+
r = cartesianrange(2 × 3, 2:7)
34+
@test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)
35+
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)
36+
@test unproduct(r) === 2:7
37+
@test length(r) == 6
38+
@test first(r) == 2
39+
@test last(r) == 7
40+
1341
a = randn(elt, 2, 2) randn(elt, 3, 3)
1442
b = randn(elt, 2, 2) randn(elt, 3, 3)
1543
c = a.a b.b

0 commit comments

Comments
 (0)