Skip to content

Commit d16ebec

Browse files
authored
Ext: add onehotarrays (#1278)
* Ext: add onehotarrays * fix * fix prints * chore: run fmt * refactor: directly call lu * Update Project.toml
1 parent a1e114b commit d16ebec

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
3636
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3737
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3838
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
39+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3940
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
4041
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
4142
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -53,6 +54,7 @@ ReactantKernelAbstractionsExt = "KernelAbstractions"
5354
ReactantMPIExt = "MPI"
5455
ReactantNNlibExt = "NNlib"
5556
ReactantOffsetArraysExt = "OffsetArrays"
57+
ReactantOneHotArraysExt = "OneHotArrays"
5658
ReactantPythonCallExt = "PythonCall"
5759
ReactantRandom123Ext = "Random123"
5860
ReactantSpecialFunctionsExt = "SpecialFunctions"
@@ -80,6 +82,7 @@ LinearAlgebra = "1.10"
8082
MPI = "0.20"
8183
NNlib = "0.9.26"
8284
OffsetArrays = "1"
85+
OneHotArrays = "0.2.10"
8386
OrderedCollections = "1"
8487
PrecompileTools = "1.2"
8588
Preferences = "1.4"

ext/ReactantOneHotArraysExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module ReactantOneHotArraysExt
2+
3+
using OneHotArrays
4+
using Reactant
5+
6+
function Reactant.traced_type_inner(
7+
@nospecialize(_::Type{OneHotArrays.OneHotArray{T,N,Np1,I}}),
8+
seen,
9+
@nospecialize(mode::Reactant.TraceMode),
10+
@nospecialize(track_numbers::Type),
11+
@nospecialize(sharding),
12+
@nospecialize(runtime)
13+
) where {T,N,Np1,I}
14+
I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime)
15+
T2 = if eltype(I2) <: Reactant.TracedRNumber && !(T <: Reactant.TracedRNumber)
16+
Reactant.TracedRNumber{T}
17+
else
18+
T
19+
end
20+
return OneHotArrays.OneHotArray{T2,N,Np1,I2}
21+
end
22+
23+
end

test/integration/onehotarrays.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Reactant, Test, OneHotArrays, Random
2+
3+
@testset "OneHotArrays" begin
4+
m = onehotbatch([10, 20, 30, 10, 10], 10:10:40)
5+
r_m = Reactant.to_rarray(m)
6+
a = rand(100, 4)
7+
r_a = Reactant.to_rarray(a)
8+
r_res = @jit r_a * r_m
9+
res = a * m
10+
@test convert(Array, r_res) res
11+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
4444
@safetestset "KernelAbstractions" include("integration/kernelabstractions.jl")
4545
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
4646
@safetestset "OffsetArrays" include("integration/offsetarrays.jl")
47+
@safetestset "OneHotArrays" include("integration/onehotarrays.jl")
4748
@safetestset "AbstractFFTs" include("integration/fft.jl")
4849
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
4950
@safetestset "Random" include("integration/random.jl")

0 commit comments

Comments
 (0)