Skip to content

Commit 83c179a

Browse files
committed
Merge branch 'main' into blockedaxes
2 parents 024fd3f + 42ec520 commit 83c179a

File tree

5 files changed

+47
-4
lines changed

5 files changed

+47
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.20"
4+
version = "0.2.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/BlockSparseArraysGradedUnitRangesExt/BlockSparseArraysGradedUnitRangesExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ function Base.similar(
5050
end
5151

5252
# Fix ambiguity error with `BlockSparseArrays.jl`.
53+
function Base.similar(
54+
a::AnyAbstractBlockSparseArray,
55+
elt::Type,
56+
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
57+
)
58+
return similar_blocksparse(a, elt, axes)
59+
end
5360
function Base.similar(
5461
a::AnyAbstractBlockSparseArray,
5562
elt::Type,

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ function Broadcast.BroadcastStyle(
3131
return DefaultArrayStyle{N}()
3232
end
3333

34-
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
35-
# TODO: Make sure this handles GPU arrays properly.
34+
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
35+
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
3636
m = Mapped(bc)
37-
return similar(first(m.args), elt, combine_axes(axes.(m.args)...))
37+
return similar(first(m.args), elt, ax)
3838
end
3939

4040
# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.

test/test_basics.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ arrayts = (Array, JLArray)
337337
@test blockstoredlength(a) == 1
338338
@test storedlength(a) == 2 * 4
339339

340+
# Test similar on broadcasted expressions.
341+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
342+
bc = Broadcast.broadcasted(+, a, a)
343+
a′ = similar(bc, Float32)
344+
@test a′ isa BlockSparseArray{Float32}
345+
@test blocktype(a′) <: arrayt{Float32,2}
346+
@test axes(a) == (blockedrange([2, 3]), blockedrange([3, 4]))
347+
348+
# Test similar on broadcasted expressions with axes specified.
349+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
350+
bc = Broadcast.broadcasted(+, a, a)
351+
a′ = similar(
352+
bc, Float32, (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
353+
)
354+
@test a′ isa BlockSparseArray{Float32}
355+
@test blocktype(a′) <: arrayt{Float32,3}
356+
@test axes(a′) == (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
357+
340358
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
341359
@views for b in [Block(1, 2), Block(2, 1)]
342360
a[b] = dev(randn(elt, size(a[b])))

test/test_gradedunitrangesext.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3838
@test axes(a, 1) isa GradedOneTo
3939
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo
4040

41+
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
42+
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
43+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
4144
for b in (a + a, 2 * a)
4245
@test size(b) == (4, 4, 4, 4)
4346
@test blocksize(b) == (2, 2, 2, 2)
@@ -54,11 +57,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
5457
@test 2 * Array(a) == b
5558
end
5659

60+
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
61+
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
62+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
5763
b = similar(a, ComplexF64)
64+
@test b isa BlockSparseArray{ComplexF64}
5865
@test eltype(b) === ComplexF64
5966

67+
a = BlockSparseVector{Float64}(undef, gradedrange([U1(0) => 1, U1(1) => 1]))
68+
b = similar(a, Float32)
69+
@test b isa BlockSparseVector{Float32}
70+
@test eltype(b) == Float32
71+
6072
# Test mixing graded axes and dense axes
6173
# in addition/broadcasting.
74+
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
75+
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
76+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
6277
for b in (a + Array(a), Array(a) + a)
6378
@test size(b) == (4, 4, 4, 4)
6479
@test blocksize(b) == (2, 2, 2, 2)
@@ -73,6 +88,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7388
@test 2 * Array(a) == b
7489
end
7590

91+
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
92+
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
93+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
7694
b = a[2:3, 2:3, 2:3, 2:3]
7795
@test size(b) == (2, 2, 2, 2)
7896
@test blocksize(b) == (2, 2, 2, 2)

0 commit comments

Comments
 (0)