@@ -237,6 +237,159 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238}
239239
240+ static bool matchingIteratorTypes (ArrayRef<utils::IteratorType> iteratorTypes,
241+ ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
242+ if (iteratorTypes.size () != expectedIteratorTypes.size ()) return false ;
243+ for (auto [orig, expected] : llvm::zip_equal (iteratorTypes, expectedIteratorTypes)) {
244+ if (orig != expected) return false ;
245+ }
246+ return true ;
247+ }
248+
249+ static mlir::AffineExpr getAffineMapDim (ArrayAttr indexingMaps,
250+ uint32_t mapIndex, uint32_t dimIndex) {
251+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue ();
252+ // uint32_t nResults = affineMap.getNumResults();
253+ // llvm::outs()<<affineMap<<"\n";
254+ // llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
255+ // llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
256+ // llvm::outs().flush();
257+ return affineMap.getResult (dimIndex);
258+ }
259+
260+ static std::string inferBasedOnRank2ConvIteratorTypes (GenericOp genericOp) {
261+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
262+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
263+ utils::IteratorType::parallel, utils::IteratorType::reduction
264+ };
265+
266+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes))
267+ return " linalg.conv_1d" ;
268+ return " " ;
269+ }
270+
271+ static std::string inferBasedOnRank4ConvIteratorTypes (GenericOp genericOp) {
272+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
273+ if (indexingMaps.size () != 3 ) return " " ;
274+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
275+ // Conv 1D
276+ // depthwise_conv_1d_ncw_cw
277+ // depthwise_conv_1d_nwc_wc
278+ // ["parallel", "parallel", "parallel", "reduction"]
279+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
280+ utils::IteratorType::parallel, utils::IteratorType::parallel,
281+ utils::IteratorType::parallel, utils::IteratorType::reduction
282+ };
283+ // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
284+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
285+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes)) {
286+ if (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 1 ))
287+ return " linalg.depthwise_conv_1d_ncw_cw" ;
288+ else if (getAffineMapDim (indexingMaps, fIndex , 1 ) == getAffineMapDim (indexingMaps, oIndex, 2 ))
289+ return " linalg.depthwise_conv_1d_nwc_wc" ;
290+ }
291+
292+ //
293+ expectedIteratorTypes[2 ] = utils::IteratorType::reduction;
294+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes)) {
295+ return " linalg.conv_2d" ;
296+ }
297+ return " " ;
298+ }
299+
300+ static std::string inferBasedOnRank5ConvIteratorTypes (GenericOp genericOp) {
301+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
302+ if (indexingMaps.size () != 3 ) return " " ;
303+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
304+ // "parallel", "parallel", "parallel", "reduction", "reduction"]
305+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
306+ utils::IteratorType::parallel, utils::IteratorType::parallel,
307+ utils::IteratorType::parallel, utils::IteratorType::parallel,
308+ utils::IteratorType::reduction
309+ };
310+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes))
311+ return " linalg.depthwise_conv_1d_nwc_wcm" ;
312+
313+ expectedIteratorTypes[3 ] = utils::IteratorType::reduction;
314+ // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
315+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
316+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes)) {
317+ if (getAffineMapDim (indexingMaps, fIndex , 2 ) == getAffineMapDim (indexingMaps, oIndex, 2 ))
318+ return " linalg.conv_1d_nwc_wcf" ;
319+ else if (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 1 ))
320+ return " linalg.conv_1d_ncw_fcw" ;
321+ }
322+ return " " ;
323+ }
324+
325+ static std::string inferBasedOnRank7ConvIteratorTypes (GenericOp genericOp) {
326+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
327+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
328+ utils::IteratorType::parallel, utils::IteratorType::reduction
329+ };
330+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes))
331+ return " linalg.conv_1d" ;
332+ return " " ;
333+ }
334+
335+ static std::string inferBasedOnRank8ConvIteratorTypes (GenericOp genericOp) {
336+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
337+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
338+ utils::IteratorType::parallel, utils::IteratorType::reduction
339+ };
340+ if (matchingIteratorTypes (iteratorTypes, expectedIteratorTypes))
341+ return " linalg.conv_1d" ;
342+ return " " ;
343+ }
344+
345+ static std::string inferConvolutionKind (GenericOp genericOp) {
346+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray ();
347+ unsigned totalIterators = iteratorTypes.size ();
348+ switch (totalIterators) {
349+ case 2 :
350+ return inferBasedOnRank2ConvIteratorTypes (genericOp);
351+ case 4 :
352+ return inferBasedOnRank4ConvIteratorTypes (genericOp);
353+ case 5 :
354+ return inferBasedOnRank5ConvIteratorTypes (genericOp);
355+ case 7 :
356+ return inferBasedOnRank7ConvIteratorTypes (genericOp);
357+ case 8 :
358+ return inferBasedOnRank8ConvIteratorTypes (genericOp);
359+ }
360+ return " " ;
361+ }
362+
363+ // Converts linalg.generic to named linalg.*conv* where possible.
364+ static FailureOr<LinalgOp> specializeLinalgConvolutions (RewriterBase &rewriter,
365+ GenericOp genericOp) {
366+ std::string convKind = inferConvolutionKind (genericOp);
367+ if (convKind == " " ) return failure ();
368+ SmallVector<Value> inputs = genericOp.getDpsInputs ();
369+ ValueRange outputs = genericOp.getDpsInits ();
370+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray ();
371+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics ()
372+ ? TypeRange (ValueRange (outputs))
373+ : TypeRange{};
374+ LinalgOp namedOp;
375+ if (convKind == " linalg.conv_1d" ) {
376+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
377+ } else if (convKind == " linalg.conv_1d_nwc_wcf" ) {
378+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
379+ } else if (convKind == " linalg.conv_1d_ncw_fcw" ) {
380+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
381+ } else if (convKind == " linalg.depthwise_conv_1d_ncw_cw" ) {
382+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
383+ } else if (convKind == " linalg.depthwise_conv_1d_nwc_wc" ) {
384+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
385+ } else if (convKind == " linalg.conv_2d" ) {
386+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
387+ }
388+ return namedOp;
389+
390+ return failure ();
391+ }
392+
240393} // namespace
241394
242395// ===----------------------------------------------------------------------===//
@@ -316,6 +469,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
316469 if (isaContractionOpInterface (genericOp)) {
317470 return specializeLinalgContractions (rewriter, genericOp);
318471 }
472+
473+ // Convolution - e.g. *conv*
474+ if (isaConvolutionOpInterface (genericOp)) {
475+ return specializeLinalgConvolutions (rewriter, genericOp);
476+ }
319477 return failure ();
320478}
321479
0 commit comments