Skip to content

Commit 01d034f

Browse files
authored
More customizability (#126)
1 parent c38b382 commit 01d034f

File tree

7 files changed

+55
-7
lines changed

7 files changed

+55
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using BlockArrays: BlockArrays, AbstractBlockedUnitRange, Block, blockedrange, blocklasts
2+
3+
struct BlockUnitRange{T,B,CS,R<:AbstractBlockedUnitRange{T,CS}} <:
4+
AbstractBlockedUnitRange{T,CS}
5+
r::R
6+
eachblockaxis::B
7+
end
8+
function blockrange(eachblockaxis)
9+
return BlockUnitRange(blockedrange(length.(eachblockaxis)), eachblockaxis)
10+
end
11+
Base.first(r::BlockUnitRange) = first(r.r)
12+
Base.last(r::BlockUnitRange) = last(r.r)
13+
BlockArrays.blocklasts(r::BlockUnitRange) = blocklasts(r.r)
14+
eachblockaxis(r::BlockUnitRange) = r.eachblockaxis
15+
function Base.getindex(r::BlockUnitRange, I::Block{1})
16+
return eachblockaxis(r)[Int(I)] .+ (first(r.r[I]) - 1)
17+
end
18+
19+
function BlockArrays.combine_blockaxes(r1::BlockUnitRange, r2::BlockUnitRange)
20+
if eachblockaxis(r1) eachblockaxis(r2)
21+
return throw(ArgumentError("BlockUnitRanges must have the same block axes"))
22+
end
23+
return r1
24+
end

src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export BlockSparseArray,
1010

1111
# possible upstream contributions
1212
include("BlockArraysExtensions/blockedunitrange.jl")
13+
include("BlockArraysExtensions/blockrange.jl")
1314
include("BlockArraysExtensions/BlockArraysExtensions.jl")
1415

1516
# interface functions that don't have to specialize

src/abstractblocksparsearray/map.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,10 @@ end
110110
function Base.isreal(a::AnyAbstractBlockSparseArray)
111111
return @interface interface(a) isreal(a)
112112
end
113+
114+
function Base.:*(x::Number, a::AnyAbstractBlockSparseArray)
115+
return map(Base.Fix1(*, x), a)
116+
end
117+
function Base.:*(a::AnyAbstractBlockSparseArray, x::Number)
118+
return map(Base.Fix2(*, x), a)
119+
end

src/blocksparsearrayinterface/getunstoredblock.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ end
1010
) where {N}
1111
# TODO: Make sure this works for sparse or block sparse blocks, immutable
1212
# blocks, diagonal blocks, etc.!
13-
b_size = ntuple(ndims(a)) do d
14-
return length(f.axes[d][Block(I[d])])
13+
b_ax = ntuple(ndims(a)) do d
14+
return only(axes(f.axes[d][Block(I[d])]))
1515
end
16-
b = similar(eltype(a), b_size)
16+
b = similar(eltype(a), b_ax)
1717
zero!(b)
1818
return b
1919
end

test/test_basics.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ using BlockSparseArrays:
2222
BlockSparseMatrix,
2323
BlockSparseVector,
2424
BlockView,
25-
blockstoredlength,
2625
blockreshape,
27-
eachblockstoredindex,
28-
eachstoredblock,
26+
blockstoredlength,
2927
blockstype,
3028
blocktype,
29+
eachblockstoredindex,
30+
eachstoredblock,
3131
sparsemortar,
3232
view!
3333
using GPUArraysCore: @allowscalar

test/test_blockrange.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using BlockArrays: Block, blocklength
2+
using BlockSparseArrays: blockrange, eachblockaxis
3+
using Test: @test, @testset
4+
5+
@testset "blockrange" begin
6+
r = blockrange(AbstractUnitRange{Int}[Base.OneTo(3), 1:4])
7+
@test eachblockaxis(r) == [Base.OneTo(3), 1:4]
8+
@test eachblockaxis(r)[1] === Base.OneTo(3)
9+
@test eachblockaxis(r)[2] === 1:4
10+
@test r[Block(1)] == 1:3
11+
@test r[Block(2)] == 4:7
12+
@test first(r) == 1
13+
@test last(r) == 7
14+
@test blocklength(r) == 2
15+
@test r == 1:7
16+
end

0 commit comments

Comments
 (0)