Skip to content

Commit 6e779ac

Browse files
Merge pull request #139 from ranocha/hr/StructArrays
fix recursivecopy
2 parents c5f8fa7 + 036fafc commit 6e779ac

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2727
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
2828
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2929
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
30+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3031
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3132
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3233
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[targets]
35-
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "Zygote"]
36+
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "StructArrays", "Zygote"]

src/utils.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ function recursivecopy(a::AbstractArray{T,N}) where {T<:Number,N}
77
end
88

99
function recursivecopy(a::AbstractArray{T,N}) where {T<:AbstractArray,N}
10-
map(recursivecopy,a)
10+
if ArrayInterface.ismutable(a)
11+
b = similar(a)
12+
map!(recursivecopy, b, a)
13+
else
14+
ArrayInterface.restructure(a, map(recursivecopy, a))
15+
end
1116
end
1217

1318
function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:StaticArray,T2<:StaticArray,N}

test/copy_static_array_test.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, RecursiveArrayTools, StaticArrays
1+
using Test, RecursiveArrayTools, StaticArrays, StructArrays
22

33
struct ImmutableFV <: FieldVector{2,Float64}
44
a::Float64
@@ -64,3 +64,19 @@ a[2][1] *= 5
6464
b[1] = 2*b[1]
6565
copyat_or_push!(a, 2, b[1])
6666
@test a[2] == b[1]
67+
68+
# StructArray of Immutable FieldVector
69+
a = StructArray([ImmutableFV(1., 2.)])
70+
b = recursivecopy(a)
71+
@test typeof(a) == typeof(b)
72+
@test a[1] == b[1]
73+
a[1] *= 2
74+
@test a[1] != b[1]
75+
76+
# StructArray of Mutable FieldVector
77+
a = StructArray([MutableFV(1., 2.)])
78+
b = recursivecopy(a)
79+
@test typeof(a) == typeof(b)
80+
@test a[1] == b[1]
81+
a[1] *= 2
82+
@test a[1] != b[1]

0 commit comments

Comments
 (0)