1
1
module ReactantNNlibExt
2
2
3
3
using NNlib
4
- using Reactant
4
+ using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array
5
5
6
6
for (jlop, hloop) in (
7
7
(:(NNlib. tanh_fast), :tanh ),
8
8
(:(NNlib. sigmoid_fast), :logistic ),
9
9
(:(NNlib. sigmoid), :logistic ),
10
10
)
11
- @eval function $ (jlop)(x:: Reactant. TracedRArray{T,0} ) where {T}
12
- return Reactant . TracedRArray {T,0} (
11
+ @eval function $ (jlop)(x:: TracedRArray{T,0} ) where {T}
12
+ return TracedRArray {T,0} (
13
13
(),
14
14
Reactant. MLIR. IR. result (
15
15
Reactant. MLIR. Dialects. stablehlo.$ (hloop)(x. mlir_data), 1
@@ -19,18 +19,16 @@ for (jlop, hloop) in (
19
19
end
20
20
end
21
21
22
- NNlib. relu (x:: Reactant. TracedRArray{T,0} ) where {T} = max (x, zero (T))
22
+ NNlib. relu (x:: TracedRArray{T,0} ) where {T} = max (x, zero (T))
23
23
24
- function NNlib. gelu (x:: Reactant. TracedRArray{T,0} ) where {T}
24
+ function NNlib. gelu (x:: TracedRArray{T,0} ) where {T}
25
25
α = T (0.044715 )
26
26
λλ = T (√ (8 / π))
27
27
return x * sigmoid (λλ * x * muladd (x^ 2 , α, one (T)))
28
28
end
29
29
30
30
# TODO handle non finite cases
31
- function NNlib. softmax! (
32
- out:: Reactant.TracedRArray{T,N} , x:: AbstractArray ; dims= 1
33
- ) where {T,N}
31
+ function NNlib. softmax! (out:: TracedRArray{T,N} , x:: AbstractArray ; dims= 1 ) where {T,N}
34
32
max_ = NNlib. fast_maximum (x; dims)
35
33
# if all(isfinite, max_)
36
34
@fastmath out .= exp .(x .- max_)
@@ -43,8 +41,11 @@ function NNlib.softmax!(
43
41
end
44
42
45
43
function NNlib. conv (
46
- x:: Reactant.TracedRArray {T,N} , W:: Reactant.TracedRArray {T} , cdims:: DenseConvDims
44
+ x:: AnyTracedRArray {T,N} , W:: AnyTracedRArray {T} , cdims:: DenseConvDims
47
45
) where {T,N}
46
+ x = materialize_traced_array (x)
47
+ W = materialize_traced_array (W)
48
+
48
49
kernel_size = NNlib. kernel_size (cdims)
49
50
padding = NNlib. padding (cdims)
50
51
stride = NNlib. stride (cdims)
@@ -119,10 +120,12 @@ function NNlib.conv(
119
120
batch_group_count= 1 ,
120
121
)
121
122
122
- return Reactant . TracedRArray {T,N} ((), Reactant. MLIR. IR. result (conv), output_shape)
123
+ return TracedRArray {T,N} ((), Reactant. MLIR. IR. result (conv), output_shape)
123
124
end
124
125
125
- function reduce_window (f, x:: Reactant.TracedRArray{T,N} , pdims; init) where {T,N}
126
+ function reduce_window (f, x:: AnyTracedRArray{T,N} , pdims; init) where {T,N}
127
+ x = materialize_traced_array (x)
128
+
126
129
num_spatial_dims = N - 2
127
130
input_spatial_dims = 1 : num_spatial_dims
128
131
@@ -185,18 +188,16 @@ function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N
185
188
body,
186
189
)
187
190
188
- return Reactant. TracedRArray {T,N} (
189
- (), Reactant. MLIR. IR. result (reduction), size (result_type)
190
- )
191
+ return TracedRArray {T,N} ((), Reactant. MLIR. IR. result (reduction), size (result_type))
191
192
end
192
193
193
- function NNlib. maxpool (x:: Reactant.TracedRArray {T} , pdims:: NNlib.PoolDims ) where {T}
194
+ function NNlib. maxpool (x:: AnyTracedRArray {T} , pdims:: NNlib.PoolDims ) where {T}
194
195
return reduce_window (
195
196
Reactant. MLIR. Dialects. stablehlo. maximum, x, pdims; init= typemin (T)
196
197
)
197
198
end
198
199
199
- function NNlib. meanpool (x:: Reactant.TracedRArray {T} , pdims:: NNlib.PoolDims ) where {T}
200
+ function NNlib. meanpool (x:: AnyTracedRArray {T} , pdims:: NNlib.PoolDims ) where {T}
200
201
numel = prod (NNlib. kernel_size (pdims))
201
202
return reduce_window (Reactant. MLIR. Dialects. stablehlo. add, x, pdims; init= zero (T)) ./
202
203
T (numel)
0 commit comments