Skip to content

Commit feba0ca

Browse files
ptiedewsmosesgithub-actions[bot]avik-pal
authored
feat: better StructArray & StaticArray support (#2546)
* Draft to figure out better StructArray support * Simplify and generalize structarray type conversion * Start adding StaticArray support * Add StaticArray support and tweak elem_apply_while_loop to select correct container type * Revert tracing.jl * Remove info debug * Remove get_ith * Add _copyto! * format * Fix broken test and add new tests * format * add StaticArrays * Add LinearAlgebra * Remove unused function * Reuse the known destination for while loop if possible * Update ext/ReactantStructArraysExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Proposed improved support for SArrays * fix dumb mistake * Add additional changes for StaticArrays * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Cleanup * Update * Update * Format * Fix for code review * Add comments * So dumb * Correct comment in overloaded_mul function Fix comment typo in overloaded_mul function. * Update to remove anonymous functions * Update * Add a complex test * Update --------- Co-authored-by: Billy Moses <wmoses@google.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Avik Pal <avikpal@mit.edu>
1 parent 0fa8c06 commit feba0ca

File tree

7 files changed

+175
-45
lines changed

7 files changed

+175
-45
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
6464
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
6565
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
6666
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
67+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
6768
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
6869
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
6970
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
@@ -91,6 +92,7 @@ ReactantOffsetArraysExt = "OffsetArrays"
9192
ReactantOneHotArraysExt = "OneHotArrays"
9293
ReactantPythonCallExt = "PythonCall"
9394
ReactantRandom123Ext = "Random123"
95+
ReactantStaticArraysExt = "StaticArrays"
9496
ReactantSparseArraysExt = "SparseArrays"
9597
ReactantSpecialFunctionsExt = "SpecialFunctions"
9698
ReactantStatisticsExt = "Statistics"
@@ -149,6 +151,7 @@ Setfield = "1.1.2"
149151
Sockets = "1.10"
150152
SparseArrays = "1.10"
151153
SpecialFunctions = "2.4"
154+
StaticArrays = "1"
152155
StableRNGs = "1.0.4"
153156
Statistics = "1.10"
154157
StructArrays = "0.7.2"

ext/ReactantStaticArraysExt.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module ReactantStaticArraysExt
2+
3+
using Reactant
4+
import Reactant.TracedRArrayOverrides: overloaded_map, overloaded_mapreduce
5+
import Reactant.TracedLinearAlgebra: overloaded_mul
6+
7+
using StaticArrays: SArray, StaticArray
8+
9+
const SAReact{Sz,T} = StaticArray{Sz,T} where {Sz<:Tuple,T<:Reactant.TracedRNumber}
10+
11+
Base.@nospecializeinfer function Reactant.traced_type_inner(
12+
@nospecialize(FA::Type{SArray{S,T,N,L}}),
13+
seen,
14+
mode::Reactant.TraceMode,
15+
@nospecialize(track_numbers::Type),
16+
@nospecialize(ndevices),
17+
@nospecialize(runtime)
18+
) where {S,T,N,L}
19+
T_traced = Reactant.traced_type_inner(T, seen, mode, track_numbers, ndevices, runtime)
20+
return SArray{S,T_traced,N,L}
21+
end
22+
23+
function Reactant.materialize_traced_array(x::SAReact)
24+
return x
25+
end
26+
27+
# We don't want to overload map on StaticArrays since it is likely better to just unroll things
28+
overloaded_map(f, a::SAReact, rest::SAReact...) = f.(a, rest...)
29+
overloaded_mapreduce(f, op, a::SAReact; kwargs...) = mapreduce(f, op, a, kwargs...)
30+
31+
function overloaded_mul(A::SAReact, B::SAReact, alpha::Number=true, beta::Number=false)
32+
# beta is not supported since it is zero by default in Reactant
33+
# (similar is zero'd automatically for TracedRArrays)
34+
C = A * B
35+
if !(alpha isa Reactant.TracedRNumber) && isone(alpha)
36+
return C
37+
end
38+
return C .* alpha
39+
end
40+
41+
end

ext/ReactantStructArraysExt.jl

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ function Base.copy(
4040
end
4141

4242
function Reactant.broadcast_to_size(arg::StructArray{T}, rsize) where {T}
43-
new = [broadcast_to_size(c, rsize) for c in components(arg)]
44-
return StructArray{T}(NamedTuple(Base.propertynames(arg) .=> new))
43+
new = Tuple((broadcast_to_size(c, rsize) for c in components(arg)))
44+
return StructArray{T}(new)
4545
end
4646

4747
function Base.copyto!(
@@ -53,12 +53,49 @@ function Base.copyto!(
5353
bc = Broadcast.preprocess(dest, bc)
5454

5555
args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
56-
5756
res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)
57+
copyto!(dest, res)
58+
59+
return dest
60+
end
5861

62+
function Reactant.TracedRArrayOverrides._copyto!(
63+
dest::StructArray, bc::Base.Broadcast.Broadcasted{<:AbstractReactantArrayStyle}
64+
)
65+
return copyto!(dest, bc)
66+
end
67+
68+
function Base.copyto!(
69+
dest::Reactant.TracedRArray, bc::Broadcasted{StructArrayStyle{S,N}}
70+
) where {S<:AbstractReactantArrayStyle,N}
71+
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
72+
isempty(dest) && return dest
73+
74+
bc = Broadcast.preprocess(dest, bc)
75+
args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
76+
res = Reactant.TracedUtils.elem_apply_via_while_loop(bc.f, args...)
5977
return copyto!(dest, res)
6078
end
6179

80+
function alloc_sarr(bc, T)
81+
# Short circuit for Complex since in Reactant they are just a regular number
82+
T <: Complex && return similar(bc, T)
83+
asa = Base.Fix1(alloc_sarr, bc)
84+
if StructArrays.isnonemptystructtype(T)
85+
return StructArrays.buildfromschema(asa, T)
86+
else
87+
return similar(bc, T)
88+
end
89+
end
90+
91+
function Base.similar(
92+
bc::Broadcasted{StructArrayStyle{S,N}}, ::Type{ElType}
93+
) where {S<:AbstractReactantArrayStyle,N,ElType}
94+
bc′ = convert(Broadcasted{S}, bc)
95+
# It is possible that we have multiple broadcasted arguments
96+
return alloc_sarr(bc′, ElType)
97+
end
98+
6299
Base.@propagate_inbounds function StructArrays._getindex(
63100
x::StructArray{T}, I::Vararg{TracedRNumber{<:Integer}}
64101
) where {T}
@@ -67,14 +104,42 @@ Base.@propagate_inbounds function StructArrays._getindex(
67104
return createinstance(T, get_ith(cols, I...)...)
68105
end
69106

107+
setstruct(col, val, I) = @inbounds Reactant.@allowscalar col[I] = val
108+
struct SetStruct{T}
109+
I::T
110+
end
111+
(s::SetStruct)(col, val) = setstruct(col, val, s.I)
112+
(s::SetStruct)(vals) = s(vals...)
113+
70114
Base.@propagate_inbounds function Base.setindex!(
71115
s::StructArray{T,<:Any,<:Any,Int}, vals, I::TracedRNumber{TI}
72116
) where {T,TI<:Integer}
73117
valsT = maybe_convert_elt(T, vals)
74-
foreachfield((col, val) -> (@inbounds col[I] = val), s, valsT)
118+
setter = SetStruct(I)
119+
foreachfield(setter, s, valsT)
75120
return s
76121
end
77122

123+
const MRarr = Union{Reactant.AnyTracedRArray,Reactant.RArray}
124+
getstruct(col, n, I) = @inbounds Reactant.@allowscalar col[n][I...]
125+
struct GetStruct{C,Idx}
126+
cols::C
127+
I::Idx
128+
end
129+
(g::GetStruct)(n) = getstruct(g.cols, n, g.I...)
130+
131+
function StructArrays.get_ith(cols::NamedTuple{N,<:NTuple{K,<:MRarr}}, I...) where {N,K}
132+
getter = GetStruct(cols, I)
133+
ith = ntuple(getter, Val(K))
134+
return ith
135+
end
136+
137+
function StructArrays.get_ith(cols::NTuple{K,<:MRarr}, I...) where {K}
138+
getter = GetStruct(cols, I)
139+
ith = ntuple(getter, Val(K))
140+
return ith
141+
end
142+
78143
Base.@nospecializeinfer function Reactant.traced_type_inner(
79144
@nospecialize(prev::Type{StructArray{ET,N,C,I}}),
80145
seen,
@@ -90,41 +155,15 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
90155
return StructArray{ET_traced,N,C_traced,index_type(fieldtypes(C_traced))}
91156
end
92157

93-
function Reactant.make_tracer(
94-
seen,
95-
@nospecialize(prev::StructArray{NT,N}),
96-
@nospecialize(path),
97-
mode;
98-
track_numbers=false,
99-
sharding=Reactant.Sharding.Sharding.NoSharding(),
100-
runtime=nothing,
101-
kwargs...,
102-
) where {NT<:NamedTuple,N}
103-
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
104-
components = getfield(prev, :components)
105-
if mode == TracedToTypes
106-
push!(path, typeof(prev))
107-
for c in components
108-
make_tracer(seen, c, path, mode; track_numbers, sharding, runtime, kwargs...)
109-
end
110-
return nothing
111-
end
112-
traced_components = make_tracer(
113-
seen,
114-
components,
115-
append_path(path, 1),
116-
mode;
117-
track_numbers,
118-
sharding,
119-
runtime,
120-
kwargs...,
121-
)
122-
T_traced = traced_type(typeof(prev), Val(mode), track_numbers, sharding, runtime)
123-
return StructArray{first(T_traced.parameters)}(traced_components)
124-
end
125-
126158
@inline function Reactant.traced_getfield(@nospecialize(obj::StructArray), field)
127159
return Base.getfield(obj, field)
128160
end
129161

162+
# This is to tell StructArrays to leave these array types alone.
163+
StructArrays.staticschema(::Type{<:Reactant.AnyTracedRArray}) = NamedTuple{()}
164+
StructArrays.staticschema(::Type{<:Reactant.RArray}) = NamedTuple{()}
165+
StructArrays.staticschema(::Type{<:Reactant.RNumber}) = NamedTuple{()}
166+
# # Even though RNumbers we have fields we want them to be threated as empty structs
167+
StructArrays.isnonemptystructtype(::Type{<:Reactant.RNumber}) = false
168+
StructArrays.isnonemptystructtype(::Type{<:Reactant.TracedRArray}) = false
130169
end

src/TracedRArray.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ end
383383
function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
384384
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
385385
isempty(dest) && return dest
386-
387386
bc = Broadcast.preprocess(dest, bc)
388387

389388
args = (Reactant.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

src/TracedUtils.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,13 @@ function __elem_apply_loop_condition(idx_ref, fn_ref::F, res_ref, args_ref, L_re
11151115
return idx_ref[] < L_ref[]
11161116
end
11171117

1118+
struct RefFillVector{T}
1119+
data::T
1120+
end
1121+
1122+
Base.getindex(rv::RefFillVector, i) = rv.data[]
1123+
Base.broadcastable(x::RefFillVector) = x
1124+
11181125
function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) where {F}
11191126
args = args_ref[]
11201127
fn = fn_ref[]
@@ -1129,14 +1136,24 @@ function __elem_apply_loop_body(idx_ref, fn_ref::F, res_ref, args_ref, L_ref) wh
11291136
return nothing
11301137
end
11311138

1139+
scalar_arg(arg) = arg isa Base.RefValue || !(arg isa AbstractArray)
1140+
1141+
flattenarg(arg) = ReactantCore.materialize_traced_array(vec(arg))
1142+
flattenarg(arg::Ref) = RefFillVector(arg)
1143+
11321144
function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
1133-
@assert allequal(size.(args)) "All args must have the same size"
1134-
L = length(first(args))
1145+
non_ref_args = [arg for arg in args if !scalar_arg(arg)]
1146+
if !isempty(non_ref_args)
1147+
@assert allequal(size.(non_ref_args)) "All args must have the same size"
1148+
end
1149+
out_size = isempty(non_ref_args) ? () : size(first(non_ref_args))
1150+
L = isempty(non_ref_args) ? 1 : length(first(non_ref_args))
11351151
# flattening the tensors makes the auto-batching pass work nicer
1136-
flat_args = [ReactantCore.materialize_traced_array(vec(arg)) for arg in args]
1152+
flat_args = [flattenarg(arg) for arg in args]
11371153

11381154
# This wont be a mutating function so we can safely execute it once
1139-
res_tmp = @allowscalar(f([@allowscalar(arg[1]) for arg in flat_args]...))
1155+
scalar_seed_args = [@allowscalar(arg[1]) for arg in flat_args]
1156+
res_tmp = @allowscalar(f(scalar_seed_args...))
11401157

11411158
# TODO: perhaps instead of this logic, we should have
11421159
# `similar(::TracedRArray, TracedRNumber{T}) where T = similar(::TracedRArray, T)`
@@ -1146,7 +1163,12 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
11461163
else
11471164
typeof(res_tmp)
11481165
end
1149-
result = similar(first(flat_args), T_res, L)
1166+
1167+
# Before we selected the output container based on the first argument
1168+
# That doesn't work for cases when StructArrays are involved
1169+
# Since this is essentially a broadcast I'm reusing this machinery
1170+
bc = Base.Broadcast.Broadcasted(f, Tuple(args))
1171+
result = similar(bc, T_res)
11501172

11511173
ind_var = Ref(0)
11521174
f_ref = Ref(f)
@@ -1160,7 +1182,7 @@ function elem_apply_via_while_loop(f, args::Vararg{Any,Nargs}) where {Nargs}
11601182
(ind_var, f_ref, result_ref, args_ref, limit_ref),
11611183
)
11621184

1163-
return ReactantCore.materialize_traced_array(reshape(result, size(first(args))))
1185+
return ReactantCore.materialize_traced_array(reshape(result, out_size))
11641186
end
11651187

11661188
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
4040
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
4141
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4242
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
43+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4344
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4445
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
4546
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

test/integration/structarrays.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using StructArrays, Reactant, Test
1+
using StructArrays, StaticArrays, Reactant, LinearAlgebra, Test
22

33
@testset "StructArray to_rarray and make_tracer" begin
44
x = StructArray(;
@@ -69,3 +69,28 @@ end
6969
@test component_ra component
7070
end
7171
end
72+
73+
@testset "structarray with static array broadcasting" begin
74+
trel(x) = tr.(x)
75+
s = StructArray{SMatrix{2,2,Float64,4}}((
76+
fill(1.0, 4), fill(2.0, 4), fill(3.0, 4), fill(4.0, 4)
77+
))
78+
sr = Reactant.to_rarray(s)
79+
out = @jit(trel(sr))
80+
@test out trel(s)
81+
@test out isa ConcreteRArray
82+
@test @jit(sum(sr)) sum(s)
83+
end
84+
85+
@testset "structarray with complex numbers" begin
86+
s = randn(64)
87+
88+
elcom(x) = complex(x, x)
89+
sr = Reactant.to_rarray(s)
90+
out = @jit(elcom.(sr))
91+
@test out elcom.(s)
92+
@test out isa ConcreteRArray
93+
@test @jit(sum(sr)) sum(s)
94+
end
95+
96+

0 commit comments

Comments
 (0)