Skip to content

Commit 42ec520

Browse files
authored
Better overloading of similar of block sparse broadcasting expressions (#61)
1 parent 6211e57 commit 42ec520

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.20"
4+
version = "0.2.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ function Broadcast.BroadcastStyle(
3131
return DefaultArrayStyle{N}()
3232
end
3333

34-
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
35-
# TODO: Make sure this handles GPU arrays properly.
34+
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
35+
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
3636
m = Mapped(bc)
37-
return similar(first(m.args), elt, combine_axes(axes.(m.args)...))
37+
return similar(first(m.args), elt, ax)
3838
end
3939

4040
# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.

test/test_basics.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ arrayts = (Array, JLArray)
337337
@test blockstoredlength(a) == 1
338338
@test storedlength(a) == 2 * 4
339339

340+
# Test similar on broadcasted expressions.
341+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
342+
bc = Broadcast.broadcasted(+, a, a)
343+
a′ = similar(bc, Float32)
344+
@test a′ isa BlockSparseArray{Float32}
345+
@test blocktype(a′) <: arrayt{Float32,2}
346+
@test axes(a) == (blockedrange([2, 3]), blockedrange([3, 4]))
347+
348+
# Test similar on broadcasted expressions with axes specified.
349+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
350+
bc = Broadcast.broadcasted(+, a, a)
351+
a′ = similar(
352+
bc, Float32, (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
353+
)
354+
@test a′ isa BlockSparseArray{Float32}
355+
@test blocktype(a′) <: arrayt{Float32,3}
356+
@test axes(a′) == (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
357+
340358
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
341359
@views for b in [Block(1, 2), Block(2, 1)]
342360
a[b] = dev(randn(elt, size(a[b])))

0 commit comments

Comments
 (0)