Skip to content

Commit 53f9ab3

Browse files
committed
Support for immutable fields derived from AbstractArray
The interface provided by ArrayInterface.jl is now used to check whether the fields of the structure derived from `DEDataArray` are immutable. Hence, we can now have fields derived from `AbstractArray` that are immutable. Notice that it was required to move the file with the related tests to the beginning of the set. Since `recursivecopy!` and `copy_fields!` are both generated functions, then the structure `Quaternion`, which was created specifically for this test, together with the definition of the method `ArrayInterface.ismutable(::Type{<:Quaternion})` must be placed before the first call to `using DiffEqBase` in the current Julia session. Closes #507
1 parent 139e5ce commit 53f9ab3

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

src/data_array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Recursively copy fields of `src` to `dest`.
6666
Tf = src.types[i]
6767
qf = Meta.quot(f)
6868

69-
if Tf <: Union{Number,SArray}
69+
if !ArrayInterface.ismutable(Tf)
7070
expressions[i] = :( dest.$f = getfield( src, $qf ) )
7171
elseif Tf <: AbstractArray
7272
expressions[i] = :( recursivecopy!(dest.$f, getfield( src, $qf ) ) )
@@ -110,7 +110,7 @@ end
110110

111111
if f == :x
112112
expressions[i] = :( )
113-
elseif Tf <: Union{Number,SArray}
113+
elseif !ArrayInterface.ismutable(Tf)
114114
expressions[i] = :( dest.$f = getfield( src, $qf ) )
115115
elseif Tf <: AbstractArray
116116
expressions[i] = :( recursivecopy!(dest.$f, getfield( src, $qf ) ) )

test/data_array_tests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
# Note: This must be the first file executed in the tests.
2+
#
3+
# The following structure is used to test the problem reported at issue #507. We
4+
# need to define it here because, since `recursivecopy!` and `copy_fields!` are
5+
# generated functions, then the `ArrayInterface.isimmutable` method must be
6+
# defined before the first `using DiffEqBase` in a Julia session.
7+
8+
struct Quaternion{T} <: AbstractVector{T}
9+
q0::T
10+
q1::T
11+
q2::T
12+
q3::T
13+
end
14+
15+
using ArrayInterface
16+
ArrayInterface.ismutable(::Type{<:Quaternion}) = false
17+
Base.size(::Quaternion) = 4
18+
119
using DiffEqBase, RecursiveArrayTools, Test
220

321
mutable struct VectorType{T} <: DEDataVector{T}
@@ -105,3 +123,31 @@ end
105123
s0 = SimWorkspace2{Float64}(SVector{2,Float64}(1.0,4.0),1.)
106124
s1 = SimWorkspace2{Float64}(SVector{2,Float64}(2.0,1.0),1.)
107125
s0 .+ s1 == SimWorkspace2{Float64}(SVector{2,Float64}(3.0,5.0),1.)
126+
127+
# Test `recursivecopy!` in immutable structures derived from `AbstractArrays`.
128+
# See issue #507.
129+
mutable struct SimWorkspace3{T} <: DEDataVector{T}
130+
x::Vector{T}
131+
q::Quaternion{T}
132+
end
133+
134+
a = SimWorkspace3([1.0,2.0,3.0], Quaternion(cosd(15), 0.0, 0.0, sind(15)))
135+
b = SimWorkspace3([0.0,0.0,0.0], Quaternion(1.0, 0.0, 0.0, 0.0))
136+
137+
recursivecopy!(b,a)
138+
@test b.x == a.x
139+
@test b.q.q0 == a.q.q0
140+
@test b.q.q1 == a.q.q1
141+
@test b.q.q2 == a.q.q2
142+
@test b.q.q3 == a.q.q3
143+
144+
a = SimWorkspace3([1.0,2.0,3.0], Quaternion(cosd(15), 0.0, 0.0, sind(15)))
145+
b = SimWorkspace3([0.0,0.0,0.0], Quaternion(1.0, 0.0, 0.0, 0.0))
146+
147+
DiffEqBase.copy_fields!(b, a)
148+
@test b.x == [0.0,0.0,0.0]
149+
@test b.q.q0 == a.q.q0
150+
@test b.q.q1 == a.q.q1
151+
@test b.q.q2 == a.q.q2
152+
@test b.q.q3 == a.q.q3
153+

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ const is_TRAVIS = haskey(ENV,"TRAVIS")
66

77
@time begin
88
if GROUP == "All" || GROUP == "Core"
9+
@time @safetestset "Data Arrays" begin include("data_array_tests.jl") end
910
@time @safetestset "Fast Power" begin include("fastpow.jl") end
1011
@time @safetestset "Fast Broadcast" begin include("fastbc.jl") end
1112
@time @safetestset "Number of Parameters Calculation" begin include("numargs_test.jl") end
12-
@time @safetestset "Data Arrays" begin include("data_array_tests.jl") end
1313
@time @safetestset "Existence functions" begin include("existence_functions.jl") end
1414
@time @safetestset "Callbacks" begin include("callbacks.jl") end
1515
@time @safetestset "Plot Variables" begin include("plot_vars.jl") end

0 commit comments

Comments
 (0)