Skip to content

Commit be32f6b

Browse files
Merge pull request #193 from mateuszbaran/mbaran/similar-array-partition-fix
Fix shaped similar for ArrayPartition
2 parents 1259985 + 181dbb6 commit be32f6b

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/array_partition.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,29 @@ end
1919

2020
Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(A.x))
2121

22-
# ignore dims since array partitions are vectors
23-
Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A)
22+
# return ArrayPartition when possible, otherwise next best thing of the correct size
23+
function Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N}
24+
if dims == size(A)
25+
return similar(A)
26+
else
27+
return similar(A.x[1], eltype(A), dims)
28+
end
29+
end
2430

2531
# similar array partition of common type
2632
@inline function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
2733
N = npartitions(A)
2834
ArrayPartition(i->similar(A.x[i], T), N)
2935
end
3036

31-
# ignore dims since array partitions are vectors
32-
Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T)
37+
# return ArrayPartition when possible, otherwise next best thing of the correct size
38+
function Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N}
39+
if dims == size(A)
40+
return similar(A, T)
41+
else
42+
return similar(A.x[1], T, dims)
43+
end
44+
end
3345

3446
# similar array partition with different types
3547
function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}

test/partitions_test.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
6161

6262
# similar partitions
6363
@inferred similar(x)
64-
@inferred similar(x, (2, 2))
64+
@test similar(x, (4,)) isa ArrayPartition{Float64}
65+
@test (@inferred similar(x, (2, 2))) isa AbstractMatrix{Float64}
6566
@inferred similar(x, Int)
66-
@inferred similar(x, Int, (2, 2))
67+
@test similar(x, Int, (4,)) isa ArrayPartition{Int}
68+
@test (@inferred similar(x, Int, (2, 2))) isa AbstractMatrix{Int}
6769
# @inferred similar(x, Int, Float64)
6870

6971
# zero

0 commit comments

Comments
 (0)