@@ -237,106 +237,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238}
239239
240- // / Utility to match block body for linalg.pool* ops.
241- template <typename ... OpTypes>
242- static bool bodyMatcherForPoolOps (Value yieldVal, Block *body) {
243- Operation *defOp = yieldVal.getDefiningOp ();
244- // if (!defOp) return false;
245- if (!(isa_and_present<OpTypes>(defOp) || ...)) return false ;
246-
247- BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand (0 ));
248- BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand (1 ));
249- if (!lhsArg || !rhsArg) return false ;
250- return true ;
251- }
252-
253- static bool bodyMatcherForMaxSignedPoolOps (Value yieldVal, Block *body) {
254- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
255- }
256-
257- static bool bodyMatcherForMaxUnsignedPoolOps (Value yieldVal, Block *body) {
258- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
259- }
260-
261- static bool bodyMatcherForMinSignedPoolOps (Value yieldVal, Block *body) {
262- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
263- }
264-
265- static bool bodyMatcherForMinUnsignedPoolOps (Value yieldVal, Block *body) {
266- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
267- }
268-
269- static bool bodyMatcherForSumPoolOps (Value yieldVal, Block *body) {
270- return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
271- }
272-
273- static mlir::AffineExpr getAffineMapDim (ArrayAttr indexingMaps,
274- uint32_t mapIndex, uint32_t dimIndex) {
275- auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue ();
276- if (dimIndex < affineMap.getNumResults ())
277- return affineMap.getResult (dimIndex);
278- return nullptr ;
279- }
280-
281- // Check if `expr` is either:
282- // - a dimension expr alone (implying *1), or
283- // - a multiplication of dimension expr by constant.
284- bool isDimTimesConstantOrDimOnly (AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
285- if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
286- dim = dExpr;
287- constantValue = 1 ;
288- return true ;
289- }
290-
291- auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
292- if (!mulExpr || mulExpr.getKind () != AffineExprKind::Mul)
293- return false ;
294-
295- AffineExpr lhs = mulExpr.getLHS ();
296- AffineExpr rhs = mulExpr.getRHS ();
297-
298- if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
299- if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
300- dim = dExpr;
301- constantValue = cst.getValue ();
302- return true ;
303- }
304- }
305- if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
306- if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
307- dim = dExpr;
308- constantValue = cst.getValue ();
309- return true ;
310- }
311- }
312- return false ;
313- }
314-
315- bool matchConvDimAddExprPattern (ArrayAttr indexingMaps, unsigned iDim, unsigned fDim , unsigned oDim) {
316- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
317- AffineExpr inpExpr = getAffineMapDim (indexingMaps, iIndex, iDim);
318- auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
319- if (!addExpr || addExpr.getKind () != AffineExprKind::Add)
320- return false ;
321-
322- AffineExpr dim0, dim1;
323- // TODO(Abhishek-Varma): Use this information in specialize.cpp.
324- int64_t c0, c1;
325-
326- if (isDimTimesConstantOrDimOnly (addExpr.getLHS (), dim0, c0) &&
327- isDimTimesConstantOrDimOnly (addExpr.getRHS (), dim1, c1)) {
328- // Pattern matched with dims and constants extracted.
329- AffineExpr fExpr = getAffineMapDim (indexingMaps, fIndex , fDim );
330- AffineExpr oExpr = getAffineMapDim (indexingMaps, oIndex, oDim);
331- return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
332- }
333- return false ;
334- }
335-
336- bool matchConvDimExprPattern (ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
337- return getAffineMapDim (indexingMaps, aIndex, aDim) == getAffineMapDim (indexingMaps, bIndex, bDim);
338- }
339-
340240static std::string inferBasedOnRank2ConvIteratorTypes (GenericOp genericOp) {
341241 if (isaConv1DOp (genericOp)) return " linalg.conv_1d" ;
342242 return " " ;
@@ -349,41 +249,16 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
349249 return " linalg.depthwise_conv_1d_nwc_wc" ;
350250 if (isaConv2DOp (genericOp))
351251 return " linalg.conv_2d" ;
352-
353- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
354- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
355- Block *body = genericOp.getBlock ();
356- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
357- Value yieldVal = yieldOp.getOperand (0 );
358- // pooling_ncw_max
359- // pooling_ncw_sum
360- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
361- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
362- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
363- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
364- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
365- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 0 , /* oDim=*/ 2 )) {
366- if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
367- return " linalg.pooling_ncw_max" ;
368- if (bodyMatcherForSumPoolOps (yieldVal, body))
369- return " linalg.pooling_ncw_sum" ;
370- }
371- // pooling_nwc_max
372- // pooling_nwc_min
373- // pooling_nwc_sum
374- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
375- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
376- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
377- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
378- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
379- matchConvDimExprPattern (indexingMaps, iIndex, 2 , oIndex, 2 )) {
380- if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
381- return " linalg.pooling_nwc_max" ;
382- if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
383- return " linalg.pooling_nwc_min" ;
384- if (bodyMatcherForSumPoolOps (yieldVal, body))
385- return " linalg.pooling_nwc_sum" ;
386- }
252+ if (isaPoolingNcwMaxOp (genericOp))
253+ return " linalg.pooling_ncw_max" ;
254+ if (isaPoolingNcwSumOp (genericOp))
255+ return " linalg.pooling_ncw_sum" ;
256+ if (isaPoolingNwcMaxOp (genericOp))
257+ return " linalg.pooling_nwc_max" ;
258+ if (isaPoolingNwcMinOp (genericOp))
259+ return " linalg.pooling_nwc_min" ;
260+ if (isaPoolingNwcSumOp (genericOp))
261+ return " linalg.pooling_nwc_sum" ;
387262 return " " ;
388263}
389264
@@ -402,61 +277,22 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
402277 return " linalg.depthwise_conv_2d_nchw_chw" ;
403278 if (isaDepthwiseConv2DNhwcHwcOp (genericOp))
404279 return " linalg.depthwise_conv_2d_nhwc_hwc" ;
405-
406- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
407- if (indexingMaps.size () < 3 ) return " " ;
408- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
409280 if (isaConv3DOp (genericOp))
410281 return " linalg.conv_3d" ;
411-
412- Block *body = genericOp.getBlock ();
413- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
414- Value yieldVal = yieldOp.getOperand (0 );
415- // pooling_nchw_max
416- // pooling_nchw_sum
417- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
418- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
419- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
420- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
421- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
422- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 0 , /* oDim=*/ 2 ) &&
423- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 1 , /* oDim=*/ 3 )) {
424- if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
425- return " linalg.pooling_nchw_max" ;
426- if (bodyMatcherForSumPoolOps (yieldVal, body))
427- return " linalg.pooling_nchw_sum" ;
428- }
429- // pooling_nhwc_max
430- // pooling_nhwc_min
431- // pooling_nhwc_sum
432- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
433- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
434- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
435- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
436- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
437- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
438- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 )) {
439- if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
440- return " linalg.pooling_nhwc_max" ;
441- if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
442- return " linalg.pooling_nhwc_min" ;
443- if (bodyMatcherForSumPoolOps (yieldVal, body))
444- return " linalg.pooling_nhwc_sum" ;
445- }
446- // pooling_nhwc_max_unsigned
447- // pooling_nhwc_min_unsigned
448- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
449- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
450- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
451- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
452- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
453- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
454- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 )) {
455- if (bodyMatcherForMaxUnsignedPoolOps (yieldVal, body))
456- return " linalg.pooling_nhwc_max_unsigned" ;
457- if (bodyMatcherForMinUnsignedPoolOps (yieldVal, body))
458- return " linalg.pooling_nhwc_min_unsigned" ;
459- }
282+ if (isaPoolingNchwMaxOp (genericOp))
283+ return " linalg.pooling_nchw_max" ;
284+ if (isaPoolingNchwSumOp (genericOp))
285+ return " linalg.pooling_nchw_sum" ;
286+ if (isaPoolingNhwcMaxOp (genericOp))
287+ return " linalg.pooling_nhwc_max" ;
288+ if (isaPoolingNhwcMinOp (genericOp))
289+ return " linalg.pooling_nhwc_min" ;
290+ if (isaPoolingNhwcSumOp (genericOp))
291+ return " linalg.pooling_nhwc_sum" ;
292+ if (isaPoolingNhwcMaxUnsignedOp (genericOp))
293+ return " linalg.pooling_nhwc_max_unsigned" ;
294+ if (isaPoolingNhwcMinUnsignedOp (genericOp))
295+ return " linalg.pooling_nhwc_min_unsigned" ;
460296 return " " ;
461297}
462298
@@ -491,31 +327,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
491327 return " linalg.depthwise_conv_3d_ncdhw_cdhw" ;
492328 if (isaDepthwiseConv3DNdhwcDhwcOp (genericOp))
493329 return " linalg.depthwise_conv_3d_ndhwc_dhwc" ;
494-
495- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
496- if (indexingMaps.size () < 3 ) return " " ;
497- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
498- Block *body = genericOp.getBlock ();
499- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
500- Value yieldVal = yieldOp.getOperand (0 );
501- // pooling_ndhwc_max
502- // pooling_ndhwc_min
503- // pooling_ndhwc_sum
504- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
505- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
506- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
507- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
508- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
509- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
510- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ) &&
511- matchConvDimExprPattern (indexingMaps, iIndex, 4 , oIndex, 4 )) {
512- if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
513- return " linalg.pooling_ndhwc_max" ;
514- if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
515- return " linalg.pooling_ndhwc_min" ;
516- if (bodyMatcherForSumPoolOps (yieldVal, body))
517- return " linalg.pooling_ndhwc_sum" ;
518- }
330+ if (isaPoolingNdhwcMaxOp (genericOp))
331+ return " linalg.pooling_ndhwc_max" ;
332+ if (isaPoolingNdhwcMinOp (genericOp))
333+ return " linalg.pooling_ndhwc_min" ;
334+ if (isaPoolingNdhwcSumOp (genericOp))
335+ return " linalg.pooling_ndhwc_sum" ;
519336 return " " ;
520337}
521338
0 commit comments