Skip to content

Commit aec22b2

Browse files
authored
add conversion between block arrays (#27)
1 parent 90ef213 commit aec22b2

File tree

4 files changed

+58
-11
lines changed

4 files changed

+58
-11
lines changed

src/BlockArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("blocksizes.jl")
1818
include("blockindices.jl")
1919
include("blockarray.jl")
2020
include("pseudo_blockarray.jl")
21+
include("convert.jl")
2122
include("show.jl")
2223

2324

src/blockarray.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ function BlockArray(blocks::Array{R, N}, block_sizes::Vararg{Vector{Int}, N}) wh
2727
return BlockArray{T, N, R}(blocks, BlockSizes(block_sizes...))
2828
end
2929

30-
3130
const BlockMatrix{T, R <: AbstractMatrix{T}} = BlockArray{T, 2, R}
3231
const BlockVector{T, R <: AbstractVector{T}} = BlockArray{T, 1, R}
3332
const BlockVecOrMat{T, R} = Union{BlockMatrix{T, R}, BlockVector{T, R}}
@@ -60,19 +59,21 @@ function BlockArray{T, N, R <: AbstractArray{T,N}}(::Type{R}, block_sizes::Block
6059
BlockArray{T,N,R}(blocks, block_sizes)
6160
end
6261

63-
@generated function BlockArray{T, N}(arr::AbstractArray{T, N}, block_sizes::Vararg{Vector{Int}, N})
64-
return quote
65-
for i in 1:N
66-
if sum(block_sizes[i]) != size(arr, i)
67-
throw(DimensionMismatch("block size for dimension $i: $(block_sizes[i]) does not sum to the array size: $(size(arr, i))"))
68-
end
62+
function BlockArray{T, N}(arr::AbstractArray{T, N}, block_sizes::Vararg{Vector{Int}, N})
63+
for i in 1:N
64+
if sum(block_sizes[i]) != size(arr, i)
65+
throw(DimensionMismatch("block size for dimension $i: $(block_sizes[i]) does not sum to the array size: $(size(arr, i))"))
6966
end
67+
end
68+
BlockArray(arr, BlockSizes(block_sizes...))
69+
end
7070

71-
_block_sizes = BlockSizes(block_sizes...)
72-
block_arr = BlockArray(typeof(arr), _block_sizes)
73-
@nloops $N i i->(1:nblocks(_block_sizes, i)) begin
71+
@generated function BlockArray{T, N}(arr::AbstractArray{T, N}, block_sizes::BlockSizes{N})
72+
return quote
73+
block_arr = BlockArray(typeof(arr), block_sizes)
74+
@nloops $N i i->(1:nblocks(block_sizes, i)) begin
7475
block_index = @ntuple $N i
75-
indices = globalrange(_block_sizes, block_index)
76+
indices = globalrange(block_sizes, block_index)
7677
setblock!(block_arr, arr[indices...], block_index...)
7778
end
7879

src/convert.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function Base.convert{T, T2, N}(::Type{BlockArray{T, N}}, A::PseudoBlockArray{T2, N})
2+
BlockArray(convert(Array{T, N}, Array(A)), A.block_sizes)
3+
end
4+
Base.convert{T, T2, N, R}(::Type{BlockArray{T, N, R}}, A::PseudoBlockArray{T2, N}) = convert(BlockArray{T, N}, A)
5+
Base.convert{T1, T2, N}(::Type{BlockArray{T1}}, A::PseudoBlockArray{T2, N}) = convert(BlockArray{T1, N}, A)
6+
Base.convert{T, N}(::Type{BlockArray}, A::PseudoBlockArray{T, N}) = convert(BlockArray{T, N}, A)
7+
BlockArray(A::BlockArray) = convert(BlockArray, A)
8+
9+
function Base.convert{T, T2, N}(::Type{PseudoBlockArray{T, N}}, A::BlockArray{T2, N})
10+
PseudoBlockArray(convert(Array{T, N}, Array(A)), A.block_sizes)
11+
end
12+
Base.convert{T, T2, N, R}(::Type{PseudoBlockArray{T, N, R}}, A::BlockArray{T2, N}) = convert(PseudoBlockArray{T, N}, A)
13+
Base.convert{T, N}(::Type{PseudoBlockArray}, A::BlockArray{T, N}) = convert(PseudoBlockArray{T, N}, A)
14+
Base.convert{T1, T2, N}(::Type{PseudoBlockArray{T1}}, A::BlockArray{T2, N}) = convert(PseudoBlockArray{T1, N}, A)
15+
PseudoBlockArray(A::BlockArray) = convert(PseudoBlockArray, A)

test/test_blockarrays.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,36 @@ end
111111
end
112112
end
113113

114+
@testset "convert" begin
115+
# Could probably be DRY'd.
116+
A = PseudoBlockArray(rand(2,3), [1,1], [2,1])
117+
C = convert(BlockArray, A)
118+
@test C == A == BlockArray(A)
119+
@test eltype(C) == eltype(A)
120+
121+
C = convert(BlockArray{Float32}, A)
122+
@test C A BlockArray(A)
123+
@test eltype(C) == Float32
124+
125+
C = convert(BlockArray{Float32, 2}, A)
126+
@test C A BlockArray(A)
127+
@test eltype(C) == Float32
128+
129+
130+
A = BlockArray(rand(2,3), [1,1], [2,1])
131+
C = convert(PseudoBlockArray, A)
132+
@test C == A == PseudoBlockArray(A)
133+
@test eltype(C) == eltype(A)
134+
135+
C = convert(PseudoBlockArray{Float32}, A)
136+
@test C A PseudoBlockArray(A)
137+
@test eltype(C) == Float32
138+
139+
C = convert(PseudoBlockArray{Float32, 2}, A)
140+
@test C A PseudoBlockArray(A)
141+
@test eltype(C) == Float32
142+
end
143+
114144
@testset "string" begin
115145
A = BlockArray(rand(4, 5), [1,3], [2,3]);
116146
buf = IOBuffer()

0 commit comments

Comments
 (0)