@@ -36,23 +36,17 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
   }

   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        utils::IteratorType targetIterType, bool nofold) const {
+                        ArrayRef<int64_t> paddingDims,
+                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
+    assert(paddingDims.size() == padToMultipleOf.size() &&
+           "invalid pad multiples for padding dimensions");
+
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(op);

-    SmallVector<int64_t> paddingDims;
-    for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
-      if (iterType == targetIterType) {
-        paddingDims.push_back(index);
-      }
-    }
-
-    SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);
+    SmallVector<bool> packPaddings(op.getNumDpsInputs(), noFold);

-    // One is enough because they will essentially be padded to corresponding
-    // tile sizes, which should be multiple of MMA shapes.
-    SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
       auto elemType = getElementTypeOrSelf(operand.get().getType());
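For reference, the semantics of the new padToMultipleOf parameter: each padding dimension is rounded up to the next multiple of its entry, and a multiple of 1 leaves a size unchanged, which is why the old hardcoded all-ones vector deferred the real padding to the tile sizes. A minimal standalone sketch of that arithmetic, in plain C++ with a hypothetical paddedSizes helper rather than the MLIR utilities:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not the MLIR implementation: round each padded
// dimension up to the next multiple of its pad-to-multiple-of factor.
std::vector<int64_t> paddedSizes(const std::vector<int64_t> &dimSizes,
                                 const std::vector<int64_t> &padToMultipleOf) {
  assert(dimSizes.size() == padToMultipleOf.size() &&
         "invalid pad multiples for padding dimensions");
  std::vector<int64_t> result;
  for (size_t i = 0; i < dimSizes.size(); ++i) {
    int64_t m = padToMultipleOf[i];
    assert(m > 0 && "pad multiple must be positive");
    // Round up to the next multiple of m.
    result.push_back((dimSizes[i] + m - 1) / m * m);
  }
  return result;
}

// Example: paddedSizes({33, 120}, {16, 64}) yields {48, 128}; with the old
// all-ones multiples, paddedSizes({33, 120}, {1, 1}) yields {33, 120}.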
@@ -80,18 +74,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final

     // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
     // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool nofold = false;
+    bool noFold = false;
     utils::IteratorType targetIterType = utils::IteratorType::parallel;
     switch (targetDimensions) {
     case LLVMGPUMatmulPadOption::ParallelDims:
       LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
       targetIterType = utils::IteratorType::parallel;
-      nofold = false;
+      noFold = false;
       break;
     case LLVMGPUMatmulPadOption::ReductionDims:
       LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
       targetIterType = utils::IteratorType::reduction;
-      nofold = true;
+      noFold = true;
       break;
     default: // Unreachable.
       assert(false);
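The noFold flag implements the comment above: pads created for reduction dims are marked nofold so canonicalization cannot merge them into their producers, while pads for parallel dims remain foldable. A toy standalone model of that rule, assuming a simplified chain of pads (plain C++; PadOp and foldPads are illustrative stand-ins, not MLIR types):

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for a pad op: the amount of zero padding it adds, plus the
// nofold bit the pass picks per target-dim option.
struct PadOp {
  int64_t padding;
  bool noFold; // if true, canonicalization must not merge this op away
};

// Merge adjacent foldable pads; any op marked noFold is preserved.
std::vector<PadOp> foldPads(std::vector<PadOp> chain) {
  std::vector<PadOp> out;
  for (const PadOp &op : chain) {
    if (!out.empty() && !out.back().noFold && !op.noFold) {
      out.back().padding += op.padding; // fold into the previous pad
    } else {
      out.push_back(op);
    }
  }
  return out;
}

int main() {
  // Innermost reduction-dim pad marked noFold, two parallel-dim pads above.
  std::vector<PadOp> chain = {{4, true}, {8, false}, {16, false}};
  std::vector<PadOp> folded = foldPads(chain);
  // Prints 2: the reduction pad survives, the parallel pads merge into one.
  std::cout << folded.size() << "\n";
}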
@@ -106,8 +100,47 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
     });

     IRRewriter rewriter(ctx);
-    for (auto op : candidates) {
-      padWithZeroValue(rewriter, op, targetIterType, nofold);
+    for (linalg::LinalgOp op : candidates) {
+      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
+      auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
+          getLoweringConfig(op));
+      if (config) {
+        switch (targetDimensions) {
+        case LLVMGPUMatmulPadOption::ParallelDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+          break;
+        case LLVMGPUMatmulPadOption::ReductionDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
+          break;
+        default:
+          assert(false && "Unexpected target dimensions");
+          break;
+        }
+      }
+
+      // Populate padding dimensions.
+      SmallVector<int64_t> paddingDimensions;
+      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
+        if (iter == targetIterType) {
+          paddingDimensions.push_back(idx);
+        }
+      }
+
+      // Populate tile sizes. We pad to multiples of the workgroup/reduction
+      // tile sizes, depending on the selected target dimensions. This pass
+      // runs after tile sizes have been selected, so every padding dimension
+      // can be padded to its selected tile size; a zero entry falls back to 1.
+      SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
+      for (auto [idx, dim] : llvm::enumerate(paddingDimensions)) {
+        if (padMultiples[dim] != 0) {
+          padToMultipleOf[idx] = padMultiples[dim];
+        }
+      }
+
+      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
+                       noFold);
     }

     {
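Taken together, the loop above derives both arguments from the op itself: the padding dimensions from the iterator types, and the pad multiples from the lowering config's static tile sizes, with a fallback of 1 so the two lists stay the same length for the assert in padWithZeroValue. A standalone sketch of that selection logic (plain C++; the iterator types and config tile sizes are modeled with simple stand-ins):

#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Given per-loop iterator types and the static tile sizes of the chosen
// tiling level, collect the loop indices to pad and the multiple each one
// is padded to. Zero tile sizes fall back to a multiple of 1 so the two
// result vectors always have equal length.
std::pair<std::vector<int64_t>, std::vector<int64_t>>
selectPadding(const std::vector<std::string> &iterTypes,
              const std::vector<int64_t> &padMultiples,
              const std::string &targetIterType) {
  assert(iterTypes.size() == padMultiples.size());
  std::vector<int64_t> paddingDims, padToMultipleOf;
  for (size_t idx = 0; idx < iterTypes.size(); ++idx) {
    if (iterTypes[idx] != targetIterType)
      continue;
    paddingDims.push_back(static_cast<int64_t>(idx));
    padToMultipleOf.push_back(padMultiples[idx] != 0 ? padMultiples[idx] : 1);
  }
  return {paddingDims, padToMultipleOf};
}

// Example: a matmul with iterators {parallel, parallel, reduction} and
// reduction-level tile sizes {0, 0, 16} gives selectPadding(..., "reduction")
// -> dims {2}, multiples {16}: the K dimension is padded to a multiple of 16.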