Skip to content

Commit e5e4fea

Browse files
fix: broadcasting with OHA (#1353)
* fix: broadcasting with OHA * Update ext/ReactantOneHotArraysExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: convert to Array --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 2a649d7 commit e5e4fea

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.121"
4+
version = "0.2.122"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ReactantOneHotArraysExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,23 @@ function Reactant.traced_type_inner(
2020
return OneHotArrays.OneHotArray{T2,N,Np1,I2}
2121
end
2222

23+
# OneHotArray is a <: AbstractArray{Bool, M} so our usual dispatches don't work
24+
function Reactant.TracedUtils.broadcast_to_size(
25+
r::OneHotArrays.OneHotArray{T,N,Np1,<:Reactant.TracedRArray}, rsize
26+
) where {T,N,Np1}
27+
return Reactant.TracedUtils.broadcast_to_size(
28+
Reactant.TracedUtils.materialize_traced_array(r), rsize
29+
)
30+
end
31+
32+
function Reactant.TracedUtils.materialize_traced_array(r::OneHotArrays.OneHotArray)
33+
return reshape(r.indices, 1, size(r.indices)...) .== 1:(r.nlabels)
34+
end
35+
36+
function Base.Array(
37+
r::OneHotArrays.OneHotArray{T,N,Np1,<:Reactant.AbstractConcreteArray}
38+
) where {T,N,Np1}
39+
return Array(reshape(Array(r.indices), 1, size(r.indices)...) .== 1:(r.nlabels))
40+
end
41+
2342
end

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function _setindex_linear!(a::TracedRArray{T,N}, v, indices::AbstractArray) wher
376376
res = Ops.scatter_setindex(
377377
a,
378378
scalar_index_to_cartesian(vec(indices), size(a)),
379-
materialize_traced_array(vec(v)),
379+
TracedUtils.promote_to(TracedRArray{T,1}, materialize_traced_array(vec(v))),
380380
)
381381
set_mlir_data!(a, get_mlir_data(res))
382382
return a

test/integration/onehotarrays.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,16 @@ using Reactant, Test, OneHotArrays, Random
99
res = a * m
1010
@test convert(Array, r_res) res
1111
end
12+
13+
@testset "broadcasting" begin
14+
m = onehotbatch([10, 20, 30, 10, 10], 10:10:40)
15+
r_m = Reactant.to_rarray(m)
16+
x = rand(Float32, 4, 5)
17+
x_ra = Reactant.to_rarray(x)
18+
19+
res_ra = @jit r_m .+ x_ra
20+
res = m .+ x
21+
@test res_ra res
22+
23+
@test Array(r_m) isa Matrix{Bool}
24+
end

0 commit comments

Comments
 (0)