
Commit a5e2d27

[GPU] Pad fusion support for TileAndDistributeToWorkgroupsUsingForall (iree-org#20258)
Avoid generating zero-slice guards by using an explicit pattern to swap tensor.extract_slice with tensor.pad in TileAndDistributeToWorkgroupsUsingForall.

Fixes: iree-org#20253
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent deb8435 commit a5e2d27
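
For context, the "swap" the commit message refers to rewrites a tensor.extract_slice of a tensor.pad into a tensor.pad of a smaller tensor.extract_slice, so each workgroup tile re-pads only its own slice of the unpadded source. Below is a hand-written sketch of the two forms, not actual compiler output: the function names, the 64x64 tile shape, and the offset/size/padding arguments are illustrative, whereas the real pattern derives the per-tile offsets, sizes, and padding amounts from the original pad's low/high values.

  // Before the swap: the workgroup tile is carved out of the already-padded tensor.
  func.func @extract_of_pad(%src: tensor<?x?xf32>, %i: index, %j: index) -> tensor<64x64xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %src low[1, 1] high[1, 1] {
    ^bb0(%d0: index, %d1: index):
      tensor.yield %cst : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    %tile = tensor.extract_slice %padded[%i, %j] [64, 64] [1, 1]
        : tensor<?x?xf32> to tensor<64x64xf32>
    return %tile : tensor<64x64xf32>
  }

  // After the swap with zeroSliceGuard = false: slice the unpadded source and
  // re-pad just the tile, with no scf.if guarding the all-padding case.
  func.func @pad_of_extract(%src: tensor<?x?xf32>, %off0: index, %off1: index,
      %sz0: index, %sz1: index, %lo0: index, %lo1: index,
      %hi0: index, %hi1: index) -> tensor<64x64xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %src_tile = tensor.extract_slice %src[%off0, %off1] [%sz0, %sz1] [1, 1]
        : tensor<?x?xf32> to tensor<?x?xf32>
    %tile = tensor.pad %src_tile low[%lo0, %lo1] high[%hi0, %hi1] {
    ^bb0(%d0: index, %d1: index):
      tensor.yield %cst : f32
    } : tensor<?x?xf32> to tensor<64x64xf32>
    return %tile : tensor<64x64xf32>
  }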

File tree

2 files changed: +29 -0 lines changed


compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 9 additions & 0 deletions
@@ -503,6 +503,12 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   // TODO(Max191): Replace populateSwapExtractWithExpandPattern with upstream
   // MLIR version once it is available (llvm-project/pull/126898).
   populateSwapExtractWithExpandPattern(cleanupPatterns);
+  // When fusing pads we do not want to generate zeroSliceGuards when doing
+  // workgroup tiling. In `GPUApplyTilingLevelPass` we do have an option called
+  // `allowZeroSlices` that can control this but we do not want these
+  // generated if workgroup tiling is happening first.
+  cleanupPatterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
+      context, [](tensor::ExtractSliceOp) { return /*zeroSliceGuard=*/false; });
   tileAndFuseOptions.cleanupPatterns =
       FrozenRewritePatternSet(std::move(cleanupPatterns));

@@ -513,6 +519,9 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
           bool isDestinationOperand)
       -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
     Operation *owner = originalProducer.getOwner();
+    if (isa<tensor::PadOp>(owner)) {
+      return std::nullopt;
+    }
     bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
     return scf::SCFTileAndFuseOptions::ControlFnResult{
         yieldProducerReplacement};
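
Taken together, the two hunks keep pad handling out of the generic tile-and-fuse path: the fusion control function now declines tensor.pad producers (std::nullopt), which leaves a plain tensor.extract_slice of the padded value inside the workgroup scf.forall, and the cleanup pattern registered in the first hunk then performs the guard-free swap sketched earlier. What this avoids is the zero-slice guard mentioned in the comment: an scf.if that falls back to materializing pure pad values whenever a tile could land entirely inside the padded region. A hand-written illustration of that guarded form follows; it is not actual compiler output, and in practice the condition, offsets, sizes, and padding amounts are computed from the pad's low/high values rather than passed in as arguments as they are here.

  // Sketch of a zero-slice guard (hypothetical helper, illustrative shapes).
  func.func @zero_slice_guard_sketch(%src: tensor<?x?xf32>, %all_pad: i1,
      %off0: index, %off1: index, %sz0: index, %sz1: index,
      %lo0: index, %lo1: index, %hi0: index, %hi1: index) -> tensor<64x64xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %tile = scf.if %all_pad -> (tensor<64x64xf32>) {
      // The tile lies entirely in the padding: materialize pure pad values.
      %full_pad = tensor.generate {
      ^bb0(%i: index, %j: index):
        tensor.yield %cst : f32
      } : tensor<64x64xf32>
      scf.yield %full_pad : tensor<64x64xf32>
    } else {
      // The tile overlaps the source: slice the source and re-pad the tile.
      %src_tile = tensor.extract_slice %src[%off0, %off1] [%sz0, %sz1] [1, 1]
          : tensor<?x?xf32> to tensor<?x?xf32>
      %pad_tile = tensor.pad %src_tile low[%lo0, %lo1] high[%hi0, %hi1] {
      ^bb0(%i: index, %j: index):
        tensor.yield %cst : f32
      } : tensor<?x?xf32> to tensor<64x64xf32>
      scf.yield %pad_tile : tensor<64x64xf32>
    }
    return %tile : tensor<64x64xf32>
  }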

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 20 additions & 0 deletions
@@ -803,3 +803,23 @@ func.func @set_encoding_gpu(%arg0 : tensor<?x?xi8>) -> tensor<?x?x8x4x4x4x2x8xi8
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic
 // CHECK:         tensor.parallel_insert_slice
+
+// -----
+
+func.func @pad_fusion(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %0 low[1, 1] high[1, 1] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  %3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
+      ins(%padded, %1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %3 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @pad_fusion(
+// CHECK:         %[[RESULT:.+]] = scf.forall (%[[ID0:.+]], %[[ID1:.+]])
+// CHECK:           %[[PADDED:.+]] = tensor.pad
+// CHECK:           %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:        ins(%[[PADDED]]
