Skip to content

Commit 8840ab7

Browse files
authored
[BlockSparseArrays] Fix some bugs involving BlockSparseArrays with dual axes (#1488)
* [BlockSparseArrays] Fix some bugs involving BlockSparseArrays with dual axes * [NDTensors] Bump to v0.3.23
1 parent ff92b6f commit 8840ab7

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@eval module $(gensym())
22
using Compat: Returns
33
using Test: @test, @testset, @test_broken
4-
using BlockArrays: Block, blocksize
4+
using BlockArrays: Block, blockedrange, blocksize
55
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
66
using NDTensors.GradedAxes:
77
GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange
@@ -73,6 +73,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7373
# be the real test.
7474
for ax in axes(m)
7575
@test ax isa GradedUnitRange
76+
# TODO: Current `fusedims` doesn't merge
77+
# common sectors, need to fix.
7678
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
7779
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
7880
end
@@ -94,8 +96,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9496
@testset "dual axes" begin
9597
r = gradedrange([U1(0) => 2, U1(1) => 2])
9698
a = BlockSparseArray{elt}(dual(r), r)
97-
a[Block(1, 1)] = randn(elt, size(a[Block(1, 1)]))
98-
a[Block(2, 2)] = randn(elt, size(a[Block(2, 2)]))
99+
@views for b in [Block(1, 1), Block(2, 2)]
100+
a[b] = randn(elt, size(a[b]))
101+
end
102+
# TODO: Define and use `isdual` here.
103+
@test axes(a, 1) isa UnitRangeDual
104+
@test axes(a, 2) isa GradedUnitRange
105+
@test !(axes(a, 2) isa UnitRangeDual)
99106
a_dense = Array(a)
100107
@test eachindex(a) == CartesianIndices(size(a))
101108
for I in eachindex(a)
@@ -104,8 +111,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
104111
@test axes(a') == dual.(reverse(axes(a)))
105112
# TODO: Define and use `isdual` here.
106113
@test axes(a', 1) isa UnitRangeDual
114+
@test axes(a', 2) isa GradedUnitRange
107115
@test !(axes(a', 2) isa UnitRangeDual)
108116
@test isnothing(show(devnull, MIME("text/plain"), a))
117+
118+
# Check preserving dual in tensor algebra.
119+
for b in (a + a, 2 * a, 3 * a - a)
120+
@test Array(b) 2 * Array(a)
121+
# TODO: Define and use `isdual` here.
122+
@test axes(b, 1) isa UnitRangeDual
123+
@test axes(b, 2) isa GradedUnitRange
124+
@test !(axes(b, 2) isa UnitRangeDual)
125+
end
126+
127+
@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
128+
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
129+
130+
# Test case when all axes are dual.
131+
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
132+
a = BlockSparseArray{elt}(dual(r), dual(r))
133+
@views for i in [Block(1, 1), Block(2, 2)]
134+
a[i] = randn(elt, size(a[i]))
135+
end
136+
b = 2 * a
137+
@test block_nstored(b) == 2
138+
@test Array(b) == 2 * Array(a)
139+
for ax in axes(b)
140+
@test ax isa UnitRangeDual
141+
end
142+
end
143+
144+
# Test case when all axes are dual
145+
# from taking the adjoint.
146+
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
147+
a = BlockSparseArray{elt}(r, r)
148+
@views for i in [Block(1, 1), Block(2, 2)]
149+
a[i] = randn(elt, size(a[i]))
150+
end
151+
b = 2 * a'
152+
@test block_nstored(b) == 2
153+
@test Array(b) == 2 * Array(a)'
154+
for ax in axes(b)
155+
@test ax isa UnitRangeDual
156+
end
157+
end
109158
end
110159
@testset "Matrix multiplication" begin
111160
r = gradedrange([U1(0) => 2, U1(1) => 3])

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function cartesianindices(axes::Tuple, b::Block)
129129
end
130130

131131
# Get the range within a block.
132-
function blockindexrange(axis::AbstractUnitRange, r::UnitRange)
132+
function blockindexrange(axis::AbstractUnitRange, r::AbstractUnitRange)
133133
bi1 = findblockindex(axis, first(r))
134134
bi2 = findblockindex(axis, last(r))
135135
b = block(bi1)

0 commit comments

Comments
 (0)