Skip to content

Commit 6fdcbb2

Browse files
committed
Explicitly write out methods
1 parent e13b491 commit 6fdcbb2

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/named_array_partition.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,34 @@ end
2626
# fields except through `getfield` and accessor functions.
2727
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
2828

29-
function Base.similar(x::NamedArrayPartition, args...)
29+
function Base.similar(A::NamedArrayPartition{T, S}) where {T, S}
30+
NamedArrayPartition(ArrayPartition{T, S}(similar.(getfield(A, :array_partition))),
31+
getfield(A, :names_to_indices))
32+
end
33+
34+
# return ArrayPartition when possible, otherwise next best thing of the correct size
35+
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
36+
NamedArrayPartition(
37+
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices))
38+
end
39+
40+
# similar array partition of common type
41+
@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T}
42+
NamedArrayPartition(
43+
similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices))
44+
end
45+
46+
# return ArrayPartition when possible, otherwise next best thing of the correct size
47+
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
48+
NamedArrayPartition(
49+
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices))
50+
end
51+
52+
# similar array partition with different types
53+
function Base.similar(
54+
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
3055
NamedArrayPartition(
31-
similar(getfield(x, :array_partition), args...), getfield(x, :names_to_indices))
56+
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
3257
end
3358

3459
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))

test/named_array_partition_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using RecursiveArrayTools, Test
44
x = NamedArrayPartition(a = ones(10), b = rand(20))
55
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
66
@test typeof(x .^ 2) <: NamedArrayPartition
7+
@test typeof(similar(x)) <: NamedArrayPartition
8+
@test typeof(similar(x, Int)) <: NamedArrayPartition
79
@test x.a ones(10)
810
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
911
@test all(x .== x[1:end])

0 commit comments

Comments
 (0)