Skip to content

Commit 9a8951f

Browse files
committed
Fix axes of CartesianProductVector
1 parent 7c007b5 commit 9a8951f

File tree

3 files changed

+32
-16
lines changed

3 files changed

+32
-16
lines changed

src/cartesianproduct.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ unproduct(r::CartesianProductVector) = getfield(r, :values)
6262
Base.length(a::CartesianProductVector) = length(unproduct(a))
6363
Base.size(a::CartesianProductVector) = (length(a),)
6464
function Base.axes(r::CartesianProductVector)
65-
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
65+
prod = cartesianproduct(r)
66+
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
67+
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
6668
end
6769
function Base.copy(a::CartesianProductVector)
6870
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))

test/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using KroneckerArrays:
99
KroneckerArray,
1010
KroneckerStyle,
1111
CartesianProductUnitRange,
12+
CartesianProductVector,
1213
,
1314
×,
1415
arg1,
@@ -45,6 +46,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
4546
@test r[2 × 2] == 5
4647
@test r[2 × 3] == 6
4748

49+
# CartesianProductUnitRange axes
50+
r = cartesianrange((2:3) × (3:4), 2:5)
51+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
52+
53+
# CartesianProductVector axes
54+
r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9])
55+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
56+
4857
r = @constinferred(cartesianrange(2 × 3, 2:7))
4958
@test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)
5059
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)

test/test_blocksparsearrays.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,12 @@ arrayts = (Array, JLArray)
130130
@test_broken svd_compact(a)
131131
end
132132

133+
b = a[Block.(1:2), Block(2)]
134+
@test b[Block(1)] == a[Block(1, 2)]
135+
@test b[Block(2)] == a[Block(2, 2)]
136+
133137
# Broken operations
134138
@test_broken exp(a)
135-
@test_broken a[Block.(1:2), Block(2)]
136139
end
137140

138141
@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
@@ -174,19 +177,19 @@ end
174177
@test a[Block(2, 2)[(1:2) × (2:3), (:) × (2:3)]] ==
175178
a[Block(2, 2)][(1:2) × (2:3), (:) × (2:3)]
176179

177-
## # Blockwise slicing, shows up in truncated block sparse matrix factorizations.
178-
## I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
179-
## I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
180-
## I = [I1, I2]
181-
## b = a[I, I]
182-
## @test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
183-
## @test arg1(b[Block(1, 1)]) isa Eye
184-
## @test iszero(b[Block(2, 1)])
185-
## @test arg1(b[Block(2, 1)]) isa Eye
186-
## @test iszero(b[Block(1, 2)])
187-
## @test arg1(b[Block(1, 2)]) isa Eye
188-
## @test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
189-
## @test arg1(b[Block(2, 2)]) isa Eye
180+
# Blockwise slicing, shows up in truncated block sparse matrix factorizations.
181+
I1 = BlockIndexVector(Block(1), Base.Slice(Base.OneTo(2)) × [1])
182+
I2 = BlockIndexVector(Block(2), Base.Slice(Base.OneTo(3)) × [1, 3])
183+
I = [I1, I2]
184+
b = a[I, I]
185+
@test b[Block(1, 1)] == a[Block(1, 1)[(1:2) × [1], (1:2) × [1]]]
186+
@test arg1(b[Block(1, 1)]) isa Eye
187+
@test iszero(b[Block(2, 1)])
188+
@test arg1(b[Block(2, 1)]) isa Eye
189+
@test iszero(b[Block(1, 2)])
190+
@test arg1(b[Block(1, 2)]) isa Eye
191+
@test b[Block(2, 2)] == a[Block(2, 2)[(1:3) × [1, 3], (1:3) × [1, 3]]]
192+
@test arg1(b[Block(2, 2)]) isa Eye
190193

191194
# Slicing
192195
r = blockrange([2 × 2, 3 × 3])
@@ -272,7 +275,9 @@ end
272275
end
273276

274277
# Broken operations
275-
@test_broken a[Block.(1:2), Block(2)]
278+
b = a[Block.(1:2), Block(2)]
279+
@test b[Block(1)] == a[Block(1, 2)]
280+
@test b[Block(2)] == a[Block(2, 2)]
276281

277282
# svd_trunc
278283
dev = adapt(arrayt)

0 commit comments

Comments
 (0)