@@ -109,4 +109,84 @@ function NNlib.conv(
109
109
return Reactant. TracedRArray {T,N} ((), Reactant. MLIR. IR. result (conv), output_shape)
110
110
end
111
111
112
+ function reduce_window (f, x:: Reactant.TracedRArray{T,N} , pdims; init) where {T,N}
113
+ num_spatial_dims = N - 2
114
+ input_spatial_dims = 1 : num_spatial_dims
115
+
116
+ dilation = NNlib. dilation (pdims)
117
+ kernel_size = NNlib. kernel_size (pdims)
118
+ stride = NNlib. stride (pdims)
119
+ padding = NNlib. padding (pdims)
120
+
121
+ window_dimensions = [kernel_size... , 1 , 1 ]
122
+ window_strides = [stride... , 1 , 1 ]
123
+ window_dilations = [dilation... , 1 , 1 ]
124
+
125
+ output_spatial_shapes = map (input_spatial_dims) do i
126
+ K = kernel_size[i]
127
+ pl, pr = padding[2 i - 1 ], padding[2 i]
128
+ d = dilation[i]
129
+ s = stride[i]
130
+
131
+ (size (x, i) + pl + pr - d * (K - 1 ) - 1 ) ÷ s + 1
132
+ end
133
+
134
+ padding = Reactant. MLIR. IR. DenseElementsAttribute (
135
+ reshape ([padding... , 0 , 0 , 0 , 0 ], (N, 2 ))
136
+ )
137
+
138
+ output_shape = (output_spatial_shapes... , size (x, N - 1 ), size (x, N))
139
+ result_type = Reactant. MLIR. IR. TensorType (output_shape, Reactant. MLIR. IR. Type (T))
140
+
141
+ unranked = Reactant. MLIR. IR. TensorType ((), eltype (Reactant. MLIR. IR. type (x. mlir_data)))
142
+ body =
143
+ let body = Reactant. MLIR. IR. Region (),
144
+ loc = Reactant. MLIR. IR. Location (),
145
+ block = Reactant. MLIR. IR. Block ([unranked, unranked], [loc, loc])
146
+
147
+ Reactant. MLIR. IR. block! (block) do
148
+ red = f (
149
+ Reactant. MLIR. IR. argument (block, 1 ),
150
+ Reactant. MLIR. IR. argument (block, 2 );
151
+ result= nothing ,
152
+ )
153
+ Reactant. MLIR. Dialects. stablehlo. return_ ([Reactant. MLIR. IR. result (red)])
154
+ end
155
+ push! (body, block)
156
+
157
+ body
158
+ end
159
+
160
+ attr = fill (Reactant. MLIR. IR. Attribute (init), unranked)
161
+ init_value = Reactant. MLIR. IR. result (
162
+ Reactant. MLIR. Dialects. stablehlo. constant (; value= attr)
163
+ )
164
+ reduction = Reactant. MLIR. Dialects. stablehlo. reduce_window (
165
+ [x. mlir_data],
166
+ [init_value];
167
+ result_0= [result_type],
168
+ window_dimensions,
169
+ window_strides,
170
+ window_dilations,
171
+ padding,
172
+ body,
173
+ )
174
+
175
+ return Reactant. TracedRArray {T,N} (
176
+ (), Reactant. MLIR. IR. result (reduction), size (result_type)
177
+ )
178
+ end
179
+
180
+ function NNlib. maxpool (x:: Reactant.TracedRArray{T} , pdims:: NNlib.PoolDims ) where {T}
181
+ return reduce_window (
182
+ Reactant. MLIR. Dialects. stablehlo. maximum, x, pdims; init= typemin (T)
183
+ )
112
184
end
185
+
186
+ function NNlib. meanpool (x:: Reactant.TracedRArray{T} , pdims:: NNlib.PoolDims ) where {T}
187
+ numel = prod (NNlib. kernel_size (pdims))
188
+ return reduce_window (Reactant. MLIR. Dialects. stablehlo. add, x, pdims; init= zero (T)) ./
189
+ T (numel)
190
+ end
191
+
192
+ end # module ReactantNNlibExt
0 commit comments