@@ -240,6 +240,225 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
240240 return iteratorType == utils::IteratorType::reduction;
241241}
242242
243+ // -------------------------------
244+ // ---------- CONV ---------------
245+ // -------------------------------
246+
247+ // / Utility to match block body for linalg.pool* ops.
248+ template <typename ... OpTypes>
249+ static bool bodyMatcherForPoolOps (Value yieldVal, Block *body) {
250+ Operation *defOp = yieldVal.getDefiningOp ();
251+ // if (!defOp) return false;
252+ if (!(isa_and_present<OpTypes>(defOp) || ...)) return false ;
253+
254+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand (0 ));
255+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand (1 ));
256+ if (!lhsArg || !rhsArg) return false ;
257+ return true ;
258+ }
259+
260+ static bool bodyMatcherForMaxSignedPoolOps (Value yieldVal, Block *body) {
261+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
262+ }
263+
264+ static bool bodyMatcherForMaxUnsignedPoolOps (Value yieldVal, Block *body) {
265+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
266+ }
267+
268+ static bool bodyMatcherForMinSignedPoolOps (Value yieldVal, Block *body) {
269+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
270+ }
271+
272+ static bool bodyMatcherForMinUnsignedPoolOps (Value yieldVal, Block *body) {
273+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
274+ }
275+
276+ static bool bodyMatcherForSumPoolOps (Value yieldVal, Block *body) {
277+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
278+ }
279+
280+ static mlir::AffineExpr getAffineMapDim (ArrayAttr indexingMaps,
281+ uint32_t mapIndex, uint32_t dimIndex) {
282+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue ();
283+ if (dimIndex < affineMap.getNumResults ())
284+ return affineMap.getResult (dimIndex);
285+ return nullptr ;
286+ }
287+
288+ // Check if `expr` is either:
289+ // - a dimension expr alone (implying *1), or
290+ // - a multiplication of dimension expr by constant.
291+ static bool isDimTimesConstantOrDimOnly (AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
292+ if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
293+ dim = dExpr;
294+ constantValue = 1 ;
295+ return true ;
296+ }
297+
298+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
299+ if (!mulExpr || mulExpr.getKind () != AffineExprKind::Mul)
300+ return false ;
301+
302+ AffineExpr lhs = mulExpr.getLHS ();
303+ AffineExpr rhs = mulExpr.getRHS ();
304+
305+ if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
306+ if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
307+ dim = dExpr;
308+ constantValue = cst.getValue ();
309+ return true ;
310+ }
311+ }
312+ if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
313+ if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
314+ dim = dExpr;
315+ constantValue = cst.getValue ();
316+ return true ;
317+ }
318+ }
319+ return false ;
320+ }
321+
322+ static bool matchConvDimAddExprPattern (ArrayAttr indexingMaps, unsigned iDim, unsigned fDim , unsigned oDim) {
323+ unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
324+ AffineExpr inpExpr = getAffineMapDim (indexingMaps, iIndex, iDim);
325+ auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
326+ if (!addExpr || addExpr.getKind () != AffineExprKind::Add)
327+ return false ;
328+
329+ AffineExpr dim0, dim1;
330+ // TODO(Abhishek-Varma): Use this information in specialize.cpp.
331+ int64_t c0, c1;
332+
333+ if (isDimTimesConstantOrDimOnly (addExpr.getLHS (), dim0, c0) &&
334+ isDimTimesConstantOrDimOnly (addExpr.getRHS (), dim1, c1)) {
335+ // Pattern matched with dims and constants extracted.
336+ AffineExpr fExpr = getAffineMapDim (indexingMaps, fIndex , fDim );
337+ AffineExpr oExpr = getAffineMapDim (indexingMaps, oIndex, oDim);
338+ return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
339+ }
340+ return false ;
341+ }
342+
343+ static bool matchConvDimExprPattern (ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
344+ return getAffineMapDim (indexingMaps, aIndex, aDim) == getAffineMapDim (indexingMaps, bIndex, bDim);
345+ }
346+
347+ static bool verifyConvIndexingMapSizes (ArrayAttr indexingMaps, ArrayRef<int64_t > expectedSizes) {
348+ if (indexingMaps.size () != expectedSizes.size ()) return false ;
349+
350+ for (auto [indexingMap, expectedSize] : llvm::zip_equal (indexingMaps, expectedSizes)) {
351+ auto affineMap = cast<AffineMapAttr>(indexingMap).getValue ();
352+ if (affineMap.getNumResults () != expectedSize) return false ;
353+ }
354+ return true ;
355+ }
356+
357+ bool isaConv1DOp (LinalgOp op) {
358+ if (isa<linalg::Conv1DOp>(op)) return true ;
359+
360+ if (!isaConvolutionOpInterface (op)) return false ;
361+
362+ ArrayAttr indexingMaps = op.getIndexingMaps ();
363+ if (!verifyConvIndexingMapSizes (indexingMaps, {1 ,1 ,1 })) return false ;
364+
365+ // #map = affine_map<(d0, d1) -> (d0 + d1)>
366+ // #map1 = affine_map<(d0, d1) -> (d1)>
367+ // #map2 = affine_map<(d0, d1) -> (d0)>
368+ return matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 );
369+ }
370+
371+ bool isaConv1DNwcWcfOp (LinalgOp op) {
372+ if (isa<linalg::Conv1DNwcWcfOp>(op)) return true ;
373+
374+ if (!isaConvolutionOpInterface (op)) return false ;
375+
376+ ArrayAttr indexingMaps = op.getIndexingMaps ();
377+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,3 ,3 })) return false ;
378+
379+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
380+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
381+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
382+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
383+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
384+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
385+ matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 1 ) &&
386+ matchConvDimExprPattern (indexingMaps, fIndex , 2 , oIndex, 2 ));
387+ }
388+
389+ bool isaConv1DNcwFcwOp (LinalgOp op) {
390+ if (isa<linalg::Conv1DNcwFcwOp>(op)) return true ;
391+
392+ if (!isaConvolutionOpInterface (op)) return false ;
393+
394+ ArrayAttr indexingMaps = op.getIndexingMaps ();
395+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,3 ,3 })) return false ;
396+
397+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
398+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
399+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
400+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
401+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
402+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 1 ) &&
403+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
404+ matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 1 ));
405+ }
406+
407+ bool isaDepthwiseConv1DNcwCwOp (LinalgOp op) {
408+ if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true ;
409+
410+ if (!isaConvolutionOpInterface (op)) return false ;
411+
412+ ArrayAttr indexingMaps = op.getIndexingMaps ();
413+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,2 ,3 })) return false ;
414+
415+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
416+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
417+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
418+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
419+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
420+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 0 ) &&
421+ matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
422+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ));
423+ }
424+
425+ bool isaDepthwiseConv1DNwcWcOp (LinalgOp op) {
426+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true ;
427+
428+ if (!isaConvolutionOpInterface (op)) return false ;
429+
430+ ArrayAttr indexingMaps = op.getIndexingMaps ();
431+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,2 ,3 })) return false ;
432+
433+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
434+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
435+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
436+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
437+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
438+ matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 1 ) &&
439+ matchConvDimExprPattern (indexingMaps, iIndex, 2 , oIndex, 2 ) &&
440+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ));
441+ }
442+
443+ bool isaDepthwiseConv1DNwcWcmOp (LinalgOp op) {
444+ if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true ;
445+
446+ if (!isaConvolutionOpInterface (op)) return false ;
447+
448+ ArrayAttr indexingMaps = op.getIndexingMaps ();
449+ if (!verifyConvIndexingMapSizes (indexingMaps, {3 ,3 ,4 })) return false ;
450+
451+ unsigned iIndex = 0 , fIndex = 1 , oIndex = 2 ;
452+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
453+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
454+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
455+ return (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
456+ matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
457+ matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 1 ) &&
458+ matchConvDimExprPattern (indexingMaps, iIndex, 2 , oIndex, 2 ) &&
459+ matchConvDimExprPattern (indexingMaps, fIndex , 2 , oIndex, 3 ));
460+ }
461+
243462Value makeComposedPadHighOp (OpBuilder &b, Location loc, RankedTensorType type,
244463 Value source, Value pad, bool nofold,
245464 ValueRange typeDynDims) {
0 commit comments