
Commit 1aa5825

[LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (#18771)
Since #18748, tensor.pad can be fused in with tiling. This patch combines the separate parallel-dimension and reduction-dimension padding passes into a single pass that pads all dimensions at once; the resulting pads are then fused during tiling.
1 parent 1fc6e5b commit 1aa5825
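
To make the new strategy concrete, below is a minimal, self-contained C++ sketch of the combined pad-multiple computation this commit introduces. The free-standing helper combinePadMultiples is illustrative only; in the pass the same loop runs inline over the lowering config's workgroup and reduction tile sizes (see the first file below).

// Sketch only: combine per-dimension workgroup and reduction tile sizes
// into a single list of pad multiples, as the combined pass now does.
#include <algorithm>
#include <cassert>
#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// For every loop dimension, pad to the larger of the workgroup and the
// reduction tile size, so one tensor.pad covers both tiling levels.
llvm::SmallVector<int64_t>
combinePadMultiples(llvm::ArrayRef<int64_t> wgTiles,
                    llvm::ArrayRef<int64_t> redTiles) {
  assert(wgTiles.size() == redTiles.size() && "tile lists must match");
  llvm::SmallVector<int64_t> padToMultipleOf(wgTiles.size(), 1);
  for (auto [wgTile, redTile, padMultiple] :
       llvm::zip_equal(wgTiles, redTiles, padToMultipleOf))
    padMultiple = std::max({wgTile, redTile, padMultiple});
  return padToMultipleOf;
}

With illustrative matmul tiles of [64, 128, 0] at the workgroup level and [0, 0, 32] at the reduction level, the combined multiples come out to [64, 128, 32]: M, N, and K are padded in one shot instead of by two separate passes.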

File tree

6 files changed: +41 −253 lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp

Lines changed: 17 additions & 114 deletions
@@ -27,25 +27,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
 public:
   using impl::LLVMGPUPromoteMatmulToFitMMAPassBase<
       LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase;
-  explicit LLVMGPUPromoteMatmulToFitMMAPass(
-      const LLVMGPUMatmulPadOption &option) {
-    this->targetDimensions.setValue(option);
-  }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
   }

   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        ArrayRef<int64_t> paddingDims,
-                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
-    assert(paddingDims.size() == padToMultipleOf.size() &&
-           "invalid pad multiples for padding dimensions");
-
+                        ArrayRef<int64_t> padToMultipleOf) const {
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(op);

-    SmallVector<bool> nofoldFlags(op.getNumDpsInputs(), noFold);
+    SmallVector<int64_t> paddingDims =
+        llvm::to_vector(llvm::seq<int64_t>(padToMultipleOf.size()));

     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
@@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
             .setPaddingDimensions(paddingDims)
             .setPaddingValues(paddingValueAttributes)
             .setPadToMultipleOf(padToMultipleOf)
-            .setNofoldFlags(nofoldFlags)
             .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);

     FailureOr<linalg::LinalgOp> result =
@@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
     MLIRContext *ctx = &getContext();
     auto funcOp = getOperation();

-    // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
-    // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool noFold = false;
-    utils::IteratorType targetIterType = utils::IteratorType::parallel;
-    switch (targetDimensions) {
-    case LLVMGPUMatmulPadOption::ParallelDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
-      targetIterType = utils::IteratorType::parallel;
-      noFold = false;
-      break;
-    case LLVMGPUMatmulPadOption::ReductionDims:
-      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
-      targetIterType = utils::IteratorType::reduction;
-      noFold = true;
-      break;
-    default: // Unreachable.
-      assert(false);
-      break;
-    };
-
     SmallVector<linalg::LinalgOp> candidates;
     funcOp->walk([&](linalg::LinalgOp op) {
       if (linalg::isaContractionOpInterface(op)) {
@@ -101,46 +73,27 @@ class LLVMGPUPromoteMatmulToFitMMAPass final

     IRRewriter rewriter(ctx);
     for (linalg::LinalgOp op : candidates) {
-      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
       auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
           getLoweringConfig(op));
-      if (config) {
-        switch (targetDimensions) {
-        case LLVMGPUMatmulPadOption::ParallelDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
-          break;
-        case LLVMGPUMatmulPadOption::ReductionDims:
-          padMultiples = config.getStaticTilingLevelSizes(
-              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
-          break;
-        default:
-          assert(false && "Unexpected target dimensions");
-          break;
-        }
+      if (!config) {
+        continue;
       }

-      // Populate padding dimensions.
-      SmallVector<int64_t> paddingDimensions;
-      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
-        if (iter == targetIterType) {
-          paddingDimensions.push_back(idx);
-        }
-      }
+      SmallVector<int64_t> wgTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+      SmallVector<int64_t> redTiles = config.getStaticTilingLevelSizes(
+          static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);

-      // Populate tile sizes. We pad to multiples of workgroup/reduction
-      // tile sizes based on the selected target tiling dimensions.
-      // This pass is ran after the select target tiling is done to pad
-      // all dimensions to the select tile sizes.
-      SmallVector<int64_t> padToMultipleOf;
-      for (int64_t dim : paddingDimensions) {
-        if (padMultiples[dim] != 0) {
-          padToMultipleOf.push_back(padMultiples[dim]);
-        }
+      // Populate padding dimensions to maximum of possible tile sizes.
+      SmallVector<int64_t> padToMultipleOf(op.getNumLoops(), 1);
+      for (auto [wgTile, redTile, padMultiple] :
+           llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
+        padMultiple = std::max({wgTile, redTile, padMultiple});
       }
+      SmallVector<int64_t> paddingDimensions =
+          llvm::to_vector(llvm::seq<int64_t>(op.getNumLoops()));

-      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
-                       noFold);
+      padWithZeroValue(rewriter, op, padToMultipleOf);
     }

     {
@@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
         return signalPassFailure();
       }
     }
-
-    // XXX(hanchung): This is needed for pad op fusion, which will remove
-    // outer pad ops. I.e., it mainly wants to remove first pad op in the
-    // pad->extract_slice->pad chain, while the canonicalization pattern can
-    // only recognize slice->pad->slice->pad.
-    {
-      SmallVector<tensor::PadOp> padOps;
-      funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
-      for (auto op : padOps) {
-        auto srcExtractSliceOp =
-            op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-        if (!srcExtractSliceOp) {
-          continue;
-        }
-        auto producerPadOp =
-            srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
-        if (!producerPadOp) {
-          continue;
-        }
-        auto src = producerPadOp.getSource()
-                       .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
-        if (!src) {
-          continue;
-        }
-
-        rewriter.setInsertionPointAfter(src);
-        SmallVector<OpFoldResult> sizes =
-            tensor::getMixedSizes(rewriter, op.getLoc(), src);
-        SmallVector<OpFoldResult> offsets(sizes.size(),
-                                          rewriter.getIndexAttr(0));
-        SmallVector<OpFoldResult> strides(sizes.size(),
-                                          rewriter.getIndexAttr(1));
-        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-            op.getLoc(), src.getResult(), offsets, sizes, strides);
-        rewriter.startOpModification(op);
-        producerPadOp.getSourceMutable().assign(extractSliceOp.getResult());
-        rewriter.finalizeOpModification(op);
-      }
-
-      RewritePatternSet patterns(ctx);
-      tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
   }
 };
 } // namespace

-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
-  return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
-}
-
 } // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 2 additions & 7 deletions
@@ -858,25 +858,20 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
   funcPassManager.addPass(createCSEPass());

   if (usePadToModelSharedMemcpy) {
-    LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims;
-    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
+    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
   }

   // Tile to reduction loops.
   {
     GPUApplyTilingLevelPassOptions options;
     options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
+    options.allowZeroSlices = true;
     funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
     funcPassManager.addPass(affine::createLoopCoalescingPass());
     funcPassManager.addPass(createCanonicalizerPass());
     funcPassManager.addPass(createCSEPass());
   }

-  if (usePadToModelSharedMemcpy) {
-    LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims;
-    funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
-  }
-
   funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
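
For readability, this is how the padding and reduction-tiling section of addGPUVectorDistributePassPipeline reads after the patch, assembled from the hunk above; the comment on allowZeroSlices is my reading of the flag name, not wording from the patch.

// After the patch: pad once up front, then fuse the pads while tiling.
if (usePadToModelSharedMemcpy) {
  // A single invocation now pads parallel and reduction dims together.
  funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
}

// Tile to reduction loops.
{
  GPUApplyTilingLevelPassOptions options;
  options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
  // Presumably permits the zero-sized slices that pad fusion can create.
  options.allowZeroSlices = true;
  funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
  funcPassManager.addPass(affine::createLoopCoalescingPass());
  funcPassManager.addPass(createCanonicalizerPass());
  funcPassManager.addPass(createCSEPass());
}
// The second, reduction-only padding invocation after tiling is gone.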

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h

Lines changed: 0 additions & 4 deletions
@@ -103,10 +103,6 @@ verifyGPUMatmulPipeline(Operation *op,
 // Wrappers that not use tablegen options.
 //------------------------------------------------------------------------------

-enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option);
-
 enum class GPUTensorCoreType {
   WMMA = 0,
   MMA_SYNC = 1,

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td

Lines changed: 0 additions & 13 deletions
@@ -105,19 +105,6 @@ def LLVMGPUPrefetchSharedMemoryPass :
 def LLVMGPUPromoteMatmulToFitMMAPass :
     InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
   let summary = "Pass to promote contraction ops to fit mma shapes";
-  let options = [
-    Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
-           /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
-           "Select the strategy to control how multi_reduction is lowered.",
-           [{::llvm::cl::values(
-                clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
-                           "parallel",
-                           "Pad all the parallel dims for contraction ops."),
-                clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
-                           "reduction",
-                           "Pad all the reduction dims for contraction ops.")
-           )}]>
-  ];
 }

 def LLVMGPUSelectLoweringStrategyPass :

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir

Lines changed: 5 additions & 3 deletions
@@ -511,7 +511,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 // CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
 // CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
 // CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
-// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
+// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
 // CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
 // CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
 // CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
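The bound change above follows directly from padding up front: with a reduction tile (the scf.for step) of 16, a K extent of 1265 is padded to the next multiple of 16, i.e. ceil(1265 / 16) * 16 = 80 * 16 = 1280, so the loop upper bound becomes %c1280.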
@@ -581,9 +581,11 @@ hal.executable public @pad_batch_matmul {
 // CHECK-SAME: memref<196x16x24xf32
 // CHECK-SAME: vector<1x1x1xf32>
 // RHS
+// The dynamic dimension should be removed after:
+// https://github.com/llvm/llvm-project/pull/112236
 // CHECK: vector.transfer_read
-// CHECK-SAME: in_bounds = [true, true, false]
-// CHECK-SAME: memref<1x8x24xf32
+// CHECK-SAME: in_bounds = [true, false, false]
+// CHECK-SAME: memref<1x?x24xf32
 // CHECK-SAME: vector<1x1x2xf32>
 // CHECK: scf.yield
 // OUTPUT
