@@ -3,7 +3,8 @@ module RecursiveArrayToolsStructArraysExt
33import RecursiveArrayTools, StructArrays
44RecursiveArrayTools. rewrap (:: StructArrays.StructArray , u) = StructArrays. StructArray (u)
55
6- using RecursiveArrayTools: VectorOfArray
6+ using RecursiveArrayTools: VectorOfArray, VectorOfArrayStyle, ArrayInterface, unpack_voa,
7+ narrays, StaticArraysCore
78using StructArrays: StructArray
89
910const VectorOfStructArray{T, N} = VectorOfArray{T, N, <: StructArray }
@@ -17,11 +18,45 @@ const VectorOfStructArray{T, N} = VectorOfArray{T, N, <:StructArray}
1718#
1819# To avoid this, we can materialize a struct entry, modify it, and then use `setindex!`
1920# with the modified struct entry.
21+ #
2022function Base. setindex! (VA:: VectorOfStructArray{T, N} , v,
2123 I:: Int... ) where {T, N}
2224 u_I = VA. u[I[end ]]
2325 u_I[Base. front (I)... ] = v
2426 return VA. u[I[end ]] = u_I
2527end
2628
29+ for (type, N_expr) in [
30+ (Broadcast. Broadcasted{<: VectorOfArrayStyle }, :(narrays (bc))),
31+ (Broadcast. Broadcasted{<: Broadcast.DefaultArrayStyle }, :(length (dest. u)))
32+ ]
33+ @eval @inline function Base. copyto! (dest:: VectorOfStructArray ,
34+ bc:: $type )
35+ bc = Broadcast. flatten (bc)
36+ N = $ N_expr
37+ @inbounds for i in 1 : N
38+ dest_i = dest[:, i]
39+ if dest_i isa AbstractArray
40+ if ArrayInterface. ismutable (dest_i)
41+ copyto! (dest_i, unpack_voa (bc, i))
42+ else
43+ unpacked = unpack_voa (bc, i)
44+ arr_type = StaticArraysCore. similar_type (dest_i)
45+ dest_i = if length (unpacked) == 1 && length (dest_i) == 1
46+ arr_type (unpacked[1 ])
47+ elseif length (unpacked) == 1
48+ fill (copy (unpacked), arr_type)
49+ else
50+ arr_type (unpacked[j] for j in eachindex (unpacked))
51+ end
52+ end
53+ else
54+ dest_i = copy (unpack_voa (bc, i))
55+ end
56+ dest[:, i] = dest_i
57+ end
58+ dest
59+ end
60+ end
61+
2762end
0 commit comments