Skip to content

Commit 6ef0cf6

Browse files
Merge pull request #508 from ronisbr/master
Support for immutable fields derived from AbstractArray
2 parents 139e5ce + 53f9ab3 commit 6ef0cf6

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

src/data_array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Recursively copy fields of `src` to `dest`.
6666
Tf = src.types[i]
6767
qf = Meta.quot(f)
6868

69-
if Tf <: Union{Number,SArray}
69+
if !ArrayInterface.ismutable(Tf)
7070
expressions[i] = :( dest.$f = getfield( src, $qf ) )
7171
elseif Tf <: AbstractArray
7272
expressions[i] = :( recursivecopy!(dest.$f, getfield( src, $qf ) ) )
@@ -110,7 +110,7 @@ end
110110

111111
if f == :x
112112
expressions[i] = :( )
113-
elseif Tf <: Union{Number,SArray}
113+
elseif !ArrayInterface.ismutable(Tf)
114114
expressions[i] = :( dest.$f = getfield( src, $qf ) )
115115
elseif Tf <: AbstractArray
116116
expressions[i] = :( recursivecopy!(dest.$f, getfield( src, $qf ) ) )

test/data_array_tests.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
# Note: This must be the first file executed in the tests.
2+
#
3+
# The following structure is used to test the problem reported at issue #507. We
4+
# need to define it here because, since `recursivecopy!` and `copy_fields!` are
5+
# generated functions, then the `ArrayInterface.isimmutable` method must be
6+
# defined before the first `using DiffEqBase` in a Julia session.
7+
8+
struct Quaternion{T} <: AbstractVector{T}
9+
q0::T
10+
q1::T
11+
q2::T
12+
q3::T
13+
end
14+
15+
using ArrayInterface
16+
ArrayInterface.ismutable(::Type{<:Quaternion}) = false
17+
Base.size(::Quaternion) = 4
18+
119
using DiffEqBase, RecursiveArrayTools, Test
220

321
mutable struct VectorType{T} <: DEDataVector{T}
@@ -105,3 +123,31 @@ end
105123
s0 = SimWorkspace2{Float64}(SVector{2,Float64}(1.0,4.0),1.)
106124
s1 = SimWorkspace2{Float64}(SVector{2,Float64}(2.0,1.0),1.)
107125
s0 .+ s1 == SimWorkspace2{Float64}(SVector{2,Float64}(3.0,5.0),1.)
126+
127+
# Test `recursivecopy!` in immutable structures derived from `AbstractArrays`.
128+
# See issue #507.
129+
mutable struct SimWorkspace3{T} <: DEDataVector{T}
130+
x::Vector{T}
131+
q::Quaternion{T}
132+
end
133+
134+
a = SimWorkspace3([1.0,2.0,3.0], Quaternion(cosd(15), 0.0, 0.0, sind(15)))
135+
b = SimWorkspace3([0.0,0.0,0.0], Quaternion(1.0, 0.0, 0.0, 0.0))
136+
137+
recursivecopy!(b,a)
138+
@test b.x == a.x
139+
@test b.q.q0 == a.q.q0
140+
@test b.q.q1 == a.q.q1
141+
@test b.q.q2 == a.q.q2
142+
@test b.q.q3 == a.q.q3
143+
144+
a = SimWorkspace3([1.0,2.0,3.0], Quaternion(cosd(15), 0.0, 0.0, sind(15)))
145+
b = SimWorkspace3([0.0,0.0,0.0], Quaternion(1.0, 0.0, 0.0, 0.0))
146+
147+
DiffEqBase.copy_fields!(b, a)
148+
@test b.x == [0.0,0.0,0.0]
149+
@test b.q.q0 == a.q.q0
150+
@test b.q.q1 == a.q.q1
151+
@test b.q.q2 == a.q.q2
152+
@test b.q.q3 == a.q.q3
153+

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ const is_TRAVIS = haskey(ENV,"TRAVIS")
66

77
@time begin
88
if GROUP == "All" || GROUP == "Core"
9+
@time @safetestset "Data Arrays" begin include("data_array_tests.jl") end
910
@time @safetestset "Fast Power" begin include("fastpow.jl") end
1011
@time @safetestset "Fast Broadcast" begin include("fastbc.jl") end
1112
@time @safetestset "Number of Parameters Calculation" begin include("numargs_test.jl") end
12-
@time @safetestset "Data Arrays" begin include("data_array_tests.jl") end
1313
@time @safetestset "Existence functions" begin include("existence_functions.jl") end
1414
@time @safetestset "Callbacks" begin include("callbacks.jl") end
1515
@time @safetestset "Plot Variables" begin include("plot_vars.jl") end

0 commit comments

Comments
 (0)