Skip to content

Commit d8ffd0d

Browse files
feat: new scatter / gather optimization patterns (#1397)
* feat: new scatter / gather optimization patterns * fix: use iota * Update test/integration/onehotarrays.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8ee65e1 commit d8ffd0d

File tree

4 files changed

+33
-12
lines changed

4 files changed

+33
-12
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.133"
4+
version = "0.2.134"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -90,7 +90,7 @@ PythonCall = "0.9"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.12"
93-
Reactant_jll = "0.0.200"
93+
Reactant_jll = "0.0.201"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

ext/ReactantOneHotArraysExt.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ReactantOneHotArraysExt
22

33
using OneHotArrays
44
using Reactant
5+
using Reactant: TracedRArray, TracedRNumber, TracedUtils, Ops
56

67
function Reactant.traced_type_inner(
78
@nospecialize(_::Type{OneHotArrays.OneHotArray{T,N,Np1,I}}),
@@ -21,16 +22,24 @@ function Reactant.traced_type_inner(
2122
end
2223

2324
# OneHotArray is a <: AbstractArray{Bool, M} so our usual dispatches don't work
24-
function Reactant.TracedUtils.broadcast_to_size(
25+
function TracedUtils.broadcast_to_size(
2526
r::OneHotArrays.OneHotArray{T,N,Np1,<:Reactant.TracedRArray}, rsize
2627
) where {T,N,Np1}
27-
return Reactant.TracedUtils.broadcast_to_size(
28-
Reactant.TracedUtils.materialize_traced_array(r), rsize
29-
)
28+
return TracedUtils.broadcast_to_size(TracedUtils.materialize_traced_array(r), rsize)
3029
end
3130

32-
function Reactant.TracedUtils.materialize_traced_array(r::OneHotArrays.OneHotArray)
33-
return reshape(r.indices, 1, size(r.indices)...) .== 1:(r.nlabels)
31+
function TracedUtils.materialize_traced_array(r::OneHotArrays.OneHotArray)
32+
indices = vec(r.indices)
33+
N = r.nlabels
34+
B = length(indices)
35+
36+
linear_indices =
37+
TracedUtils.promote_to(TracedRArray{Int64,ndims(r.indices)}, indices) .+
38+
Ops.iota(Int64, [B]; iota_dimension=1) .* N
39+
40+
z = Ops.fill(false, (N, B))
41+
z[linear_indices] = fill(true, length(linear_indices))
42+
return reshape(z, size(r))
3443
end
3544

3645
function Base.Array(

src/Compiler.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,8 @@ function optimization_passes(;
928928
"split_convolution_into_reverse_convolution",
929929
# TODO we want to enable but may cause an infinite compile time
930930
# "concat_to_onedim_dusslice",
931+
"scatter_multiply_simplify",
932+
"unary_elementwise_scatter_simplify",
931933
]
932934

933935
# constant prop patterns
@@ -980,6 +982,7 @@ function optimization_passes(;
980982
"concat_const_prop<1>($max_constant_threshold)",
981983
"dynamic_update_slice_const_prop($max_constant_threshold)",
982984
"scatter_update_computation_const_prop",
985+
"gather_const_prop",
983986
],
984987
)
985988

test/integration/onehotarrays.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,18 @@ end
1616
x = rand(Float32, 4, 5)
1717
x_ra = Reactant.to_rarray(x)
1818

19-
res_ra = @jit r_m .+ x_ra
20-
res = m .+ x
21-
@test res_ra res
19+
@testset "addition" begin
20+
res_ra = @jit r_m .+ x_ra
21+
res = m .+ x
22+
@test res_ra res
2223

23-
@test Array(r_m) isa Matrix{Bool}
24+
@test Array(r_m) isa Matrix{Bool}
25+
end
26+
27+
@testset "multiplication" begin
28+
# Broadcasting a multiplication has special passes
29+
res_ra = @jit r_m .* x_ra
30+
res = m .* x
31+
@test res_ra res
32+
end
2433
end

0 commit comments

Comments
 (0)