Skip to content

Commit c82c3d3

Browse files
Make dilations/strides mandatory
1 parent f557fca commit c82c3d3

File tree

2 files changed

+19
-26
lines changed

2 files changed

+19
-26
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,10 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
115115
//===----------------------------------------------------------------------===//
116116

117117
/// Given a linalg `op` this function returns true if it is a convolution op of
118-
/// type `ConvOpTy` and populate the optional `dilations` and `strides`
119-
/// arguments, if present.
118+
/// type `ConvOpTy` and populate `dilations` and `strides` arguments.
120119
template <typename ConvOpTy>
121-
bool isaConvolutionOpOfType(LinalgOp op,
122-
SmallVector<int64_t> *dilations = nullptr,
123-
SmallVector<int64_t> *strides = nullptr);
120+
bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
121+
SmallVector<int64_t> *strides);
124122

125123
//===----------------------------------------------------------------------===//
126124
// Fusion / Tiling utilities

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -438,19 +438,6 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
438438
return true;
439439
}
440440

441-
/// Utility to update `dilations` and `strides` by copy the corresponding data
442-
/// from `tempDilations` and `tempStrides`.
443-
static void updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
444-
SmallVector<int64_t> *strides,
445-
ArrayRef<int64_t> tempDilations,
446-
ArrayRef<int64_t> tempStrides) {
447-
if (!(dilations && strides))
448-
return;
449-
*dilations = SmallVector<int64_t>(tempDilations);
450-
*strides = SmallVector<int64_t>(tempStrides);
451-
return;
452-
}
453-
454441
// ---------------------------------------------
455442
// Matchers for specific convolution operation.
456443
// ---------------------------------------------
@@ -497,7 +484,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
497484
// Match body
498485
if (!bodyMatcherForConvolutionOps(yieldVal, body))
499486
return false;
500-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
487+
*dilations = SmallVector<int64_t>(tempDilations);
488+
*strides = SmallVector<int64_t>(tempStrides);
501489
return true;
502490
}
503491

@@ -547,7 +535,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
547535
// Match body
548536
if (!bodyMatcherForConvolutionOps(yieldVal, body))
549537
return false;
550-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
538+
*dilations = SmallVector<int64_t>(tempDilations);
539+
*strides = SmallVector<int64_t>(tempStrides);
551540
return true;
552541
}
553542

@@ -608,7 +597,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
608597
// Match body
609598
if (!bodyMatcherForConvolutionOps(yieldVal, body))
610599
return false;
611-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
600+
*dilations = SmallVector<int64_t>(tempDilations);
601+
*strides = SmallVector<int64_t>(tempStrides);
612602
return true;
613603
}
614604

@@ -655,7 +645,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
655645
// Match body
656646
if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body))
657647
return false;
658-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
648+
*dilations = SmallVector<int64_t>(tempDilations);
649+
*strides = SmallVector<int64_t>(tempStrides);
659650
return true;
660651
}
661652

@@ -702,7 +693,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
702693
// Match body
703694
if (!bodyMatcherForMinSignedPoolOps(yieldVal, body))
704695
return false;
705-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
696+
*dilations = SmallVector<int64_t>(tempDilations);
697+
*strides = SmallVector<int64_t>(tempStrides);
706698
return true;
707699
}
708700

@@ -749,7 +741,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
749741
// Match body
750742
if (!bodyMatcherForSumPoolOps(yieldVal, body))
751743
return false;
752-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
744+
*dilations = SmallVector<int64_t>(tempDilations);
745+
*strides = SmallVector<int64_t>(tempStrides);
753746
return true;
754747
}
755748

@@ -796,7 +789,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
796789
// Match body
797790
if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
798791
return false;
799-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
792+
*dilations = SmallVector<int64_t>(tempDilations);
793+
*strides = SmallVector<int64_t>(tempStrides);
800794
return true;
801795
}
802796

@@ -843,7 +837,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
843837
// Match body
844838
if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
845839
return false;
846-
updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
840+
*dilations = SmallVector<int64_t>(tempDilations);
841+
*strides = SmallVector<int64_t>(tempStrides);
847842
return true;
848843
}
849844

0 commit comments

Comments
 (0)