diff --git a/Project.toml b/Project.toml index 973c7045..5a019d19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.6.5" +version = "0.6.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/blockrange.jl b/src/BlockArraysExtensions/blockrange.jl new file mode 100644 index 00000000..7edd0133 --- /dev/null +++ b/src/BlockArraysExtensions/blockrange.jl @@ -0,0 +1,24 @@ +using BlockArrays: BlockArrays, AbstractBlockedUnitRange, Block, blockedrange, blocklasts + +struct BlockUnitRange{T,B,CS,R<:AbstractBlockedUnitRange{T,CS}} <: + AbstractBlockedUnitRange{T,CS} + r::R + eachblockaxis::B +end +function blockrange(eachblockaxis) + return BlockUnitRange(blockedrange(length.(eachblockaxis)), eachblockaxis) +end +Base.first(r::BlockUnitRange) = first(r.r) +Base.last(r::BlockUnitRange) = last(r.r) +BlockArrays.blocklasts(r::BlockUnitRange) = blocklasts(r.r) +eachblockaxis(r::BlockUnitRange) = r.eachblockaxis +function Base.getindex(r::BlockUnitRange, I::Block{1}) + return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1) +end + +function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange) + if eachblockaxis(r1) ≠ eachblockaxis(r2) + return throw(ArgumentError("BlockUnitRanges must have the same block axes")) + end + return r1 +end diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 0c8e7716..6470fc61 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -10,6 +10,7 @@ export BlockSparseArray, # possible upstream contributions include("BlockArraysExtensions/blockedunitrange.jl") +include("BlockArraysExtensions/blockrange.jl") include("BlockArraysExtensions/BlockArraysExtensions.jl") # interface functions that don't have to specialize diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 724b0ffc..4dcec66f 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -110,3 +110,10 @@ end function Base.isreal(a::AnyAbstractBlockSparseArray) return @interface interface(a) isreal(a) end + +function Base.:*(x::Number, a::AnyAbstractBlockSparseArray) + return map(Base.Fix1(*, x), a) +end +function Base.:*(a::AnyAbstractBlockSparseArray, x::Number) + return map(Base.Fix2(*, x), a) +end diff --git a/src/blocksparsearrayinterface/getunstoredblock.jl b/src/blocksparsearrayinterface/getunstoredblock.jl index 6273bed2..36caddea 100644 --- a/src/blocksparsearrayinterface/getunstoredblock.jl +++ b/src/blocksparsearrayinterface/getunstoredblock.jl @@ -10,10 +10,10 @@ end ) where {N} # TODO: Make sure this works for sparse or block sparse blocks, immutable # blocks, diagonal blocks, etc.! - b_size = ntuple(ndims(a)) do d - return length(f.axes[d][Block(I[d])]) + b_ax = ntuple(ndims(a)) do d + return only(axes(f.axes[d][Block(I[d])])) end - b = similar(eltype(a), b_size) + b = similar(eltype(a), b_ax) zero!(b) return b end diff --git a/test/test_basics.jl b/test/test_basics.jl index bc220d2b..82e86ea7 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -22,12 +22,12 @@ using BlockSparseArrays: BlockSparseMatrix, BlockSparseVector, BlockView, - blockstoredlength, blockreshape, - eachblockstoredindex, - eachstoredblock, + blockstoredlength, blockstype, blocktype, + eachblockstoredindex, + eachstoredblock, sparsemortar, view! using GPUArraysCore: @allowscalar diff --git a/test/test_blockrange.jl b/test/test_blockrange.jl new file mode 100644 index 00000000..35675655 --- /dev/null +++ b/test/test_blockrange.jl @@ -0,0 +1,16 @@ +using BlockArrays: Block, blocklength +using BlockSparseArrays: blockrange, eachblockaxis +using Test: @test, @testset + +@testset "blockrange" begin + r = blockrange(AbstractUnitRange{Int}[Base.OneTo(3), 1:4]) + @test eachblockaxis(r) == [Base.OneTo(3), 1:4] + @test eachblockaxis(r)[1] === Base.OneTo(3) + @test eachblockaxis(r)[2] === 1:4 + @test r[Block(1)] == 1:3 + @test r[Block(2)] == 4:7 + @test first(r) == 1 + @test last(r) == 7 + @test blocklength(r) == 2 + @test r == 1:7 +end