Skip to content

Commit 2f661e1

Browse files
authored
Compile NNlib.maxpool and NNlib.meanpool (#102)
1 parent aef81ea commit 2f661e1

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,84 @@ function NNlib.conv(
109109
return Reactant.TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
110110
end
111111

112+
function reduce_window(f, x::Reactant.TracedRArray{T,N}, pdims; init) where {T,N}
113+
num_spatial_dims = N - 2
114+
input_spatial_dims = 1:num_spatial_dims
115+
116+
dilation = NNlib.dilation(pdims)
117+
kernel_size = NNlib.kernel_size(pdims)
118+
stride = NNlib.stride(pdims)
119+
padding = NNlib.padding(pdims)
120+
121+
window_dimensions = [kernel_size..., 1, 1]
122+
window_strides = [stride..., 1, 1]
123+
window_dilations = [dilation..., 1, 1]
124+
125+
output_spatial_shapes = map(input_spatial_dims) do i
126+
K = kernel_size[i]
127+
pl, pr = padding[2i - 1], padding[2i]
128+
d = dilation[i]
129+
s = stride[i]
130+
131+
(size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
132+
end
133+
134+
padding = Reactant.MLIR.IR.DenseElementsAttribute(
135+
reshape([padding..., 0, 0, 0, 0], (N, 2))
136+
)
137+
138+
output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
139+
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
140+
141+
unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data)))
142+
body =
143+
let body = Reactant.MLIR.IR.Region(),
144+
loc = Reactant.MLIR.IR.Location(),
145+
block = Reactant.MLIR.IR.Block([unranked, unranked], [loc, loc])
146+
147+
Reactant.MLIR.IR.block!(block) do
148+
red = f(
149+
Reactant.MLIR.IR.argument(block, 1),
150+
Reactant.MLIR.IR.argument(block, 2);
151+
result=nothing,
152+
)
153+
Reactant.MLIR.Dialects.stablehlo.return_([Reactant.MLIR.IR.result(red)])
154+
end
155+
push!(body, block)
156+
157+
body
158+
end
159+
160+
attr = fill(Reactant.MLIR.IR.Attribute(init), unranked)
161+
init_value = Reactant.MLIR.IR.result(
162+
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
163+
)
164+
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
165+
[x.mlir_data],
166+
[init_value];
167+
result_0=[result_type],
168+
window_dimensions,
169+
window_strides,
170+
window_dilations,
171+
padding,
172+
body,
173+
)
174+
175+
return Reactant.TracedRArray{T,N}(
176+
(), Reactant.MLIR.IR.result(reduction), size(result_type)
177+
)
178+
end
179+
180+
function NNlib.maxpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
181+
return reduce_window(
182+
Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
183+
)
112184
end
185+
186+
function NNlib.meanpool(x::Reactant.TracedRArray{T}, pdims::NNlib.PoolDims) where {T}
187+
numel = prod(NNlib.kernel_size(pdims))
188+
return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
189+
T(numel)
190+
end
191+
192+
end # module ReactantNNlibExt

test/nn.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,15 @@ mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
8484

8585
@test res_reactant res
8686
end
87+
88+
@testset "$f" for f in (NNlib.meanpool, NNlib.maxpool)
89+
img = randn(Float32, 224, 224, 3, 2)
90+
img_reactant = Reactant.ConcreteRArray(img)
91+
92+
f_reactant = Reactant.compile(f, (img_reactant, (3, 3)))
93+
94+
res_reactant = f_reactant(img_reactant, (3, 3))
95+
res = f(img, (3, 3))
96+
97+
@test res_reactant res
98+
end

0 commit comments

Comments
 (0)