From e544c13785a0473789fd8e7a6bbee1ddce7d3e5a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 16 Jan 2025 12:47:03 -0500 Subject: [PATCH] Fix zero-dimensional conversion --- Project.toml | 2 +- src/abstractblocksparsearray/map.jl | 19 +++++++++++++++++-- test/basics/test_basics.jl | 7 +++++-- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index c7a8b187..11d59835 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.8" +version = "0.2.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 05529447..52aec400 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -1,6 +1,6 @@ using ArrayLayouts: LayoutArray using BlockArrays: blockisequal -using DerivableInterfaces: @interface, interface +using DerivableInterfaces: @interface, AbstractArrayInterface, interface using LinearAlgebra: Adjoint, Transpose using SparseArraysBase: SparseArraysBase, SparseArrayStyle @@ -49,15 +49,30 @@ function reblock( return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] end +# `map!` specialized to zero-dimensional inputs. +function map_zero_dim! end + +@interface ::AbstractArrayInterface function map_zero_dim!( + f, a_dest::AbstractArray, a_srcs::AbstractArray... +) + a_dest[] = f.(map(a_src -> a_src[], a_srcs)...) + return a_dest +end + # TODO: Move to `blocksparsearrayinterface/map.jl`. # TODO: Rewrite this so that it takes the blocking structure # made by combining the blocking of the axes (i.e. the blocking that # is used to determine `union_stored_blocked_cartesianindices(...)`). # `reblock` is a partial solution to that, but a bit ad-hoc. ## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function. -@interface ::AbstractBlockSparseArrayInterface function Base.map!( +@interface interface::AbstractBlockSparseArrayInterface function Base.map!( f, a_dest::AbstractArray, a_srcs::AbstractArray... ) + if iszero(ndims(a_dest)) + @interface interface map_zero_dim!(f, a_dest, a_srcs...) + return a_dest + end + a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) BI_dest = blockindexrange(a_dest, I) diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index afe403fe..12b3f8aa 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -187,8 +187,9 @@ arrayts = (Array, JLArray) @test iszero(@allowscalar(a[CartesianIndex()])) @test a[Block()] == dev(fill(0)) @test iszero(@allowscalar(a[Block()][])) - # Broken: - ## @test b[Block()[]] == 2 + @test @allowscalar(a[Block()[]]) == 0 + @test Array(a) isa Array{elt,0} + @test Array(a) == fill(0) for b in ( (b = copy(a); @allowscalar b[] = 2; b), (b = copy(a); @allowscalar b[CartesianIndex()] = 2; b), @@ -206,6 +207,8 @@ arrayts = (Array, JLArray) @test b[Block()] == dev(fill(2)) @test @allowscalar(b[Block()][]) == 2 @test @allowscalar(b[Block()[]]) == 2 + @test Array(b) isa Array{elt,0} + @test Array(b) == fill(2) end end