diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index d0737e38..de8fa91a 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -26,6 +26,36 @@ end # fields except through `getfield` and accessor functions. ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) +function Base.similar(A::NamedArrayPartition) + NamedArrayPartition( + similar(getfield(A, :array_partition)), getfield(A, :names_to_indices)) +end + +# return ArrayPartition when possible, otherwise next best thing of the correct size +function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} + NamedArrayPartition( + similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices)) +end + +# similar array partition of common type +@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T} + NamedArrayPartition( + similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices)) +end + +# return ArrayPartition when possible, otherwise next best thing of the correct size +function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N} + NamedArrayPartition( + similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices)) +end + +# similar array partition with different types +function Base.similar( + A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S} + NamedArrayPartition( + similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices)) +end + Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} diff --git a/src/utils.jl b/src/utils.jl index 58945065..8c34c316 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,6 @@ unrolled_foreach!(f, t::Tuple) = (f(t[1]); unrolled_foreach!(f, Base.tail(t))) unrolled_foreach!(f, ::Tuple{}) = nothing - """ ```julia recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T, N}}) @@ -131,7 +130,6 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N}, end end - for type in [AbstractArray, AbstractVectorOfArray] @eval function recursivefill!(b::$type{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N} fill!(b, a) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index cf32d89d..408d5bf5 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -177,10 +177,11 @@ Base.parent(vec::VectorOfArray) = vec.u #### 2-argument # first element representative -function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) +function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothing, + variables = nothing, parameters = nothing, independent_variables = nothing) sys = SymbolCache(something(variables, []), - something(parameters, []), - something(independent_variables, [])) + something(parameters, []), + something(independent_variables, [])) _size = size(vec[1]) T = eltype(vec[1]) return DiffEqArray{ @@ -199,10 +200,12 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector; discretes = nothin end # T and N from type -function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector; + discretes = nothing, variables = nothing, parameters = nothing, + independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} sys = SymbolCache(something(variables, []), - something(parameters, []), - something(independent_variables, [])) + something(parameters, []), + something(independent_variables, [])) return DiffEqArray{ eltype(eltype(vec)), N + 1, @@ -221,7 +224,8 @@ end #### 3-argument # NTuple, T from type -function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}; discretes = nothing) where {T, N} +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, + ::NTuple{N, Int}; discretes = nothing) where {T, N} DiffEqArray{ eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, typeof(discretes)}( vec, @@ -232,8 +236,11 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int end # NTuple parameter -function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} - DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec, +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}; + discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} + DiffEqArray{ + eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}( + vec, ts, p, nothing, @@ -241,10 +248,11 @@ function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, end # first element representative -function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) +function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = nothing, + variables = nothing, parameters = nothing, independent_variables = nothing) sys = SymbolCache(something(variables, []), - something(parameters, []), - something(independent_variables, [])) + something(parameters, []), + something(independent_variables, [])) _size = size(vec[1]) T = eltype(vec[1]) return DiffEqArray{ @@ -263,11 +271,14 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p; discretes = not end # T and N from type -function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p; discretes = nothing, variables = nothing, parameters = nothing, independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p; + discretes = nothing, variables = nothing, parameters = nothing, + independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}} sys = SymbolCache(something(variables, []), - something(parameters, []), - something(independent_variables, [])) - DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, + something(parameters, []), + something(independent_variables, [])) + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), + typeof(p), typeof(sys), typeof(discretes)}(vec, ts, p, sys, @@ -277,7 +288,8 @@ end #### 4-argument # NTuple, T from type -function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p; discretes = nothing) where {T, N} +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, + ::NTuple{N, Int}, p; discretes = nothing) where {T, N} DiffEqArray{ eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}( vec, @@ -288,8 +300,10 @@ function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int end # NTuple parameter -function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} - DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p::NTuple{N2, Int}, sys; + discretes = nothing) where {T, N, VT <: AbstractArray{T, N}, N2} + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), + typeof(p), typeof(sys), typeof(discretes)}(vec, ts, p, sys, @@ -316,8 +330,10 @@ function DiffEqArray(vec::AbstractVector, ts::AbstractVector, p, sys; discretes end # T and N from type -function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} - DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec, +function DiffEqArray(vec::AbstractVector{VT}, ts::AbstractVector, p, sys; + discretes = nothing) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N + 1, typeof(vec), typeof(ts), + typeof(p), typeof(sys), typeof(discretes)}(vec, ts, p, sys, @@ -327,7 +343,8 @@ end #### 5-argument # NTuple, T from type -function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N} +function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, + ::NTuple{N, Int}, p, sys; discretes = nothing) where {T, N} DiffEqArray{ eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}( vec, @@ -942,7 +959,7 @@ end VectorOfArray(rewrap(parent, u)) end -rewrap(::Array,u) = u +rewrap(::Array, u) = u rewrap(parent, u) = convert(typeof(parent), u) for (type, N_expr) in [ diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 557b65be..5d11fab0 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -27,7 +27,8 @@ sol_new = DiffEqArray(sol.u[1:10], @test all(isequal.(all_variable_symbols(sol), all_variable_symbols(sol_new))) @test all(isequal.(all_variable_symbols(sol), [x, RHS])) @test all(isequal.(all_symbols(sol), all_symbols(sol_new))) -@test all([any(isequal(sym), all_symbols(sol)) for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]]) +@test all([any(isequal(sym), all_symbols(sol)) + for sym in [x, RHS, τ, t, Initial(x), Initial(RHS)]]) @test sol[solvedvariables] == sol[[x]] @test sol_new[solvedvariables] == sol_new[[x]] @test sol[allvariables] == sol[[x, RHS]] diff --git a/test/gpu/arraypartition_gpu.jl b/test/gpu/arraypartition_gpu.jl index 3b335855..80f9a8d6 100644 --- a/test/gpu/arraypartition_gpu.jl +++ b/test/gpu/arraypartition_gpu.jl @@ -1,7 +1,6 @@ using RecursiveArrayTools, CUDA, Test CUDA.allowscalar(false) - # Test indexing with colon a = (CUDA.zeros(5), CUDA.zeros(5)) pA = ArrayPartition(a) diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index d8164edf..d5647bad 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -4,6 +4,8 @@ using RecursiveArrayTools, Test x = NamedArrayPartition(a = ones(10), b = rand(20)) @test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition @test typeof(x .^ 2) <: NamedArrayPartition + @test typeof(similar(x)) <: NamedArrayPartition + @test typeof(similar(x, Int)) <: NamedArrayPartition @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end])