Skip to content

Commit af21f68

Browse files
committed
feat: support NNlib
1 parent fbcb5bb commit af21f68

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array
55

66
for (jlop, hloop) in (
77
(:(NNlib.tanh_fast), :tanh),
88
(:(NNlib.sigmoid_fast), :logistic),
99
(:(NNlib.sigmoid), :logistic),
1010
)
11-
@eval function $(jlop)(x::Reactant.TracedRArray{T,0}) where {T}
12-
return Reactant.TracedRArray{T,0}(
11+
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
12+
return TracedRArray{T,0}(
1313
(),
1414
Reactant.MLIR.IR.result(
1515
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
@@ -19,18 +19,16 @@ for (jlop, hloop) in (
1919
end
2020
end
2121

22-
NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
22+
NNlib.relu(x::TracedRArray{T,0}) where {T} = max(x, zero(T))
2323

24-
function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
24+
function NNlib.gelu(x::TracedRArray{T,0}) where {T}
2525
α = T(0.044715)
2626
λλ = T((8 / π))
2727
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
2828
end
2929

3030
# TODO handle non finite cases
31-
function NNlib.softmax!(
32-
out::Reactant.TracedRArray{T,N}, x::AbstractArray; dims=1
33-
) where {T,N}
31+
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
3432
max_ = NNlib.fast_maximum(x; dims)
3533
#if all(isfinite, max_)
3634
@fastmath out .= exp.(x .- max_)
@@ -43,8 +41,11 @@ function NNlib.softmax!(
4341
end
4442

4543
function NNlib.conv(
46-
x::Reactant.TracedRArray{T,N}, W::Reactant.TracedRArray{T}, cdims::DenseConvDims
44+
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
4745
) where {T,N}
46+
x = materialize_traced_array(x)
47+
W = materialize_traced_array(W)
48+
4849
kernel_size = NNlib.kernel_size(cdims)
4950
padding = NNlib.padding(cdims)
5051
stride = NNlib.stride(cdims)
@@ -119,10 +120,12 @@ function NNlib.conv(
119120
batch_group_count=1,
120121
)
121122

122-
return Reactant.TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
123+
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
123124
end
124125

125-
function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N}
126+
function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
127+
x = materialize_traced_array(x)
128+
126129
num_spatial_dims = N - 2
127130
input_spatial_dims = 1:num_spatial_dims
128131

@@ -185,18 +188,16 @@ function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N
185188
body,
186189
)
187190

188-
return Reactant.TracedRArray{T,N}(
189-
(), Reactant.MLIR.IR.result(reduction), size(result_type)
190-
)
191+
return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
191192
end
192193

193-
function NNlib.maxpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
194+
function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
194195
return reduce_window(
195196
Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
196197
)
197198
end
198199

199-
function NNlib.meanpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
200+
function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
200201
numel = prod(NNlib.kernel_size(pdims))
201202
return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
202203
T(numel)

0 commit comments

Comments
 (0)