@@ -27,25 +27,19 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
 public:
   using impl::LLVMGPUPromoteMatmulToFitMMAPassBase<
       LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase;
-  explicit LLVMGPUPromoteMatmulToFitMMAPass(
-      const LLVMGPUMatmulPadOption &option) {
-    this->targetDimensions.setValue(option);
-  }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
   }
 
   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        ArrayRef<int64_t> paddingDims,
-                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
-    assert(paddingDims.size() == padToMultipleOf.size() &&
-           "invalid pad multiples for padding dimensions");
-
+                        ArrayRef<int64_t> padToMultipleOf) const {
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(op);
 
-    SmallVector<bool> nofoldFlags(op.getNumDpsInputs(), noFold);
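+    // Pad all dimensions; [0, padToMultipleOf.size()) are the padding dims.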
+    SmallVector<int64_t> paddingDims =
+        llvm::to_vector(llvm::seq<int64_t>(padToMultipleOf.size()));
 
     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
@@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
             .setPaddingDimensions(paddingDims)
             .setPaddingValues(paddingValueAttributes)
             .setPadToMultipleOf(padToMultipleOf)
-            .setNofoldFlags(nofoldFlags)
             .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);
 
     FailureOr<linalg::LinalgOp> result =
@@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
     MLIRContext *ctx = &getContext();
     auto funcOp = getOperation();
 
-    // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
-    // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool noFold = false;
-    utils::IteratorType targetIterType = utils::IteratorType::parallel;
-    switch (targetDimensions) {
-    case LLVMGPUMatmulPadOption::ParallelDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
-      targetIterType = utils::IteratorType::parallel;
-      noFold = false;
-      break;
-    case LLVMGPUMatmulPadOption::ReductionDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
-      targetIterType = utils::IteratorType::reduction;
-      noFold = true;
-      break;
-    default: // Unreachable.
-      assert(false);
-      break;
-    };
-
     SmallVector<linalg::LinalgOp> candidates;
     funcOp->walk([&](linalg::LinalgOp op) {
       if (linalg::isaContractionOpInterface(op)) {
@@ -101,46 +73,31 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
 
     IRRewriter rewriter(ctx);
     for (linalg::LinalgOp op : candidates) {
-      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
       auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
           getLoweringConfig(op));
-      if (config) {
-        switch (targetDimensions) {
-        case LLVMGPUMatmulPadOption::ParallelDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
-          break;
-        case LLVMGPUMatmulPadOption::ReductionDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
-          break;
-        default:
-          assert(false && "Unexpected target dimensions");
-          break;
-        }
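+      // Skip ops without a lowering config; no tile sizes to pad to.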
+      if (!config) {
+        continue;
       }
 
-      // Populate padding dimensions.
-      SmallVector<int64_t> paddingDimensions;
-      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
-        if (iter == targetIterType) {
-          paddingDimensions.push_back(idx);
-        }
-      }
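+      // Static tile sizes per loop at the workgroup and reduction levels.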
+      SmallVector<int64_t> wgTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+      SmallVector<int64_t> redTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
 
-      // Populate tile sizes. We pad to multiples of workgroup/reduction
-      // tile sizes based on the selected target tiling dimensions.
-      // This pass is ran after the select target tiling is done to pad
-      // all dimensions to the select tile sizes.
-      SmallVector<int64_t> padToMultipleOf;
-      for (int64_t dim : paddingDimensions) {
-        if (padMultiples[dim] != 0) {
-          padToMultipleOf.push_back(padMultiples[dim]);
-        }
+      // Pad each dim to the max of its workgroup and reduction tile sizes.
+      SmallVector<int64_t> padToMultipleOf(op.getNumLoops(), 1);
+      for (auto [wgTile, redTile, padMultiple] :
+           llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
+        padMultiple = std::max({wgTile, redTile, padMultiple});
       }
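+      // Illustrative example: wgTiles = [64, 128, 0], redTiles = [0, 0, 32]
+      // would yield padToMultipleOf = [64, 128, 32].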
+      SmallVector<int64_t> paddingDimensions =
+          llvm::to_vector(llvm::seq<int64_t>(op.getNumLoops()));
 
-      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
-                       noFold);
+      padWithZeroValue(rewriter, op, padToMultipleOf);
     }
 
     {
@@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
         return signalPassFailure();
       }
     }
-
-    // XXX(hanchung): This is needed for pad op fusion, which will remove
-    // outer pad ops. I.e., it mainly wants to remove first pad op in the
-    // pad->extract_slice->pad chain, while the canonicalization pattern can
-    // only recognize slice->pad->slice->pad.
-    {
-      SmallVector<tensor::PadOp> padOps;
-      funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
-      for (auto op : padOps) {
-        auto srcExtractSliceOp =
-            op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-        if (!srcExtractSliceOp) {
-          continue;
-        }
-        auto producerPadOp =
-            srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
-        if (!producerPadOp) {
-          continue;
-        }
-        auto src = producerPadOp.getSource()
-                       .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
-        if (!src) {
-          continue;
-        }
-
-        rewriter.setInsertionPointAfter(src);
-        SmallVector<OpFoldResult> sizes =
-            tensor::getMixedSizes(rewriter, op.getLoc(), src);
-        SmallVector<OpFoldResult> offsets(sizes.size(),
-                                          rewriter.getIndexAttr(0));
-        SmallVector<OpFoldResult> strides(sizes.size(),
-                                          rewriter.getIndexAttr(1));
-        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-            op.getLoc(), src.getResult(), offsets, sizes, strides);
-        rewriter.startOpModification(op);
-        producerPadOp.getSourceMutable().assign(extractSliceOp.getResult());
-        rewriter.finalizeOpModification(op);
-      }
-
-      RewritePatternSet patterns(ctx);
-      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
   }
 };
 } // namespace
 
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
-  return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
-}
-
 } // namespace mlir::iree_compiler
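The effect of the new pad-multiple computation is easiest to see with concrete
numbers. The following standalone C++ sketch mirrors the loop added above,
using plain std::vector in place of SmallVector and hypothetical tile sizes;
it is not part of the commit. A tile size of 0 means the dimension is untiled
at that level (the case the old code skipped explicitly).

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical M/N/K tile sizes for one contraction op.
      std::vector<int64_t> wgTiles = {64, 128, 0}; // workgroup level
      std::vector<int64_t> redTiles = {0, 0, 32};  // reduction level

      // llvm::zip_equal in the pass asserts equal lengths; a shared index
      // does the same job here. Each dim is padded to the largest tile size.
      std::vector<int64_t> padToMultipleOf(wgTiles.size(), 1);
      for (size_t i = 0; i < padToMultipleOf.size(); ++i) {
        padToMultipleOf[i] =
            std::max({wgTiles[i], redTiles[i], padToMultipleOf[i]});
      }

      for (int64_t m : padToMultipleOf) {
        std::cout << m << ' '; // prints: 64 128 32
      }
      std::cout << '\n';
      return 0;
    }

Because every dimension takes the maximum across both tiling levels, a single
code path now covers what previously required the ParallelDims/ReductionDims
option, which is why the option, the noFold flag, and the pad-op fusion
workaround could all be deleted.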