Skip to content

Commit 587d6fa

Browse files
committed
Merge branch 'main' into adapt_blockedperm
2 parents ba45d6b + 42ec520 commit 587d6fa

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ function Broadcast.BroadcastStyle(
3131
return DefaultArrayStyle{N}()
3232
end
3333

34-
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
35-
# TODO: Make sure this handles GPU arrays properly.
34+
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
35+
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
3636
m = Mapped(bc)
37-
return similar(first(m.args), elt, combine_axes(axes.(m.args)...))
37+
return similar(first(m.args), elt, ax)
3838
end
3939

4040
# Catches cases like `dest .= value` or `dest .= value1 .+ value2`.

test/test_basics.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ arrayts = (Array, JLArray)
337337
@test blockstoredlength(a) == 1
338338
@test storedlength(a) == 2 * 4
339339

340+
# Test similar on broadcasted expressions.
341+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
342+
bc = Broadcast.broadcasted(+, a, a)
343+
a′ = similar(bc, Float32)
344+
@test a′ isa BlockSparseArray{Float32}
345+
@test blocktype(a′) <: arrayt{Float32,2}
346+
@test axes(a) == (blockedrange([2, 3]), blockedrange([3, 4]))
347+
348+
# Test similar on broadcasted expressions with axes specified.
349+
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
350+
bc = Broadcast.broadcasted(+, a, a)
351+
a′ = similar(
352+
bc, Float32, (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
353+
)
354+
@test a′ isa BlockSparseArray{Float32}
355+
@test blocktype(a′) <: arrayt{Float32,3}
356+
@test axes(a′) == (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
357+
340358
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
341359
@views for b in [Block(1, 2), Block(2, 1)]
342360
a[b] = dev(randn(elt, size(a[b])))

0 commit comments

Comments
 (0)