diff --git a/Project.toml b/Project.toml index 936a5c21..f10faa8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.20" +version = "0.2.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 57ebe783..d8ab5ec8 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -31,10 +31,10 @@ function Broadcast.BroadcastStyle( return DefaultArrayStyle{N}() end -function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type) - # TODO: Make sure this handles GPU arrays properly. +function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax) + # TODO: Make this more generic, base it off sure this handles GPU arrays properly. m = Mapped(bc) - return similar(first(m.args), elt, combine_axes(axes.(m.args)...)) + return similar(first(m.args), elt, ax) end # Catches cases like `dest .= value` or `dest .= value1 .+ value2`. diff --git a/test/test_basics.jl b/test/test_basics.jl index de982d6b..e1ac5890 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -337,6 +337,24 @@ arrayts = (Array, JLArray) @test blockstoredlength(a) == 1 @test storedlength(a) == 2 * 4 + # Test similar on broadcasted expressions. + a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))) + bc = Broadcast.broadcasted(+, a, a) + a′ = similar(bc, Float32) + @test a′ isa BlockSparseArray{Float32} + @test blocktype(a′) <: arrayt{Float32,2} + @test axes(a) == (blockedrange([2, 3]), blockedrange([3, 4])) + + # Test similar on broadcasted expressions with axes specified. + a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))) + bc = Broadcast.broadcasted(+, a, a) + a′ = similar( + bc, Float32, (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2])) + ) + @test a′ isa BlockSparseArray{Float32} + @test blocktype(a′) <: arrayt{Float32,3} + @test axes(a′) == (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2])) + a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))) @views for b in [Block(1, 2), Block(2, 1)] a[b] = dev(randn(elt, size(a[b])))