Commit ad68964

[LLVMGPU] Pad to intrinsic shape in LLVMGPUPadAndVectorDistribute pipeline (#18632)
This patch makes the LLVMGPUPromoteToFitMMA pass pad to a multiple of the intrinsic shape instead of padding to 1. Fixes #18602
1 parent 6001f9c commit ad68964
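
For intuition, here is a minimal standalone C++ sketch of the round-up-to-multiple computation the padding now performs. It is not part of the patch; the helper name and printed values are illustrative. The 1281 -> 1296 case matches the updated expectations in promote_matmul_to_fit_mma.mlir below, where the reduction tile size is 16.

#include <cstdint>
#include <iostream>

// Round `size` up to the next multiple of `multiple`. With the old behavior
// (pad to a multiple of 1) an unaligned size stayed unchanged.
int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
  return ((size + multiple - 1) / multiple) * multiple;
}

int main() {
  // K = 1281 with a reduction tile of 16 (itself a multiple of the MMA K shape).
  std::cout << roundUpToMultiple(1281, 16) << "\n"; // 1296
  // The previous pad-to-1 behavior left the size as-is.
  std::cout << roundUpToMultiple(1281, 1) << "\n";  // 1281
  return 0;
}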

File tree

3 files changed (+117 lines, -22 lines)


compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp

Lines changed: 50 additions & 17 deletions
@@ -36,23 +36,17 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
   }
 
   void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
-                        utils::IteratorType targetIterType, bool nofold) const {
+                        ArrayRef<int64_t> paddingDims,
+                        ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
+    assert(paddingDims.size() == padToMultipleOf.size() &&
+           "invalid pad multiples for padding dimensions");
+
     LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(op);
 
-    SmallVector<int64_t> paddingDims;
-    for (auto [index, iterType] : llvm::enumerate(op.getIteratorTypesArray())) {
-      if (iterType == targetIterType) {
-        paddingDims.push_back(index);
-      }
-    }
-
-    SmallVector<bool> packPaddings(op.getNumDpsInputs(), nofold);
+    SmallVector<bool> packPaddings(op.getNumDpsInputs(), noFold);
 
-    // One is enough because they will essentially be padded to corresponding
-    // tile sizes, which should be multiple of MMA shapes.
-    SmallVector<int64_t> padToMultipleOf(paddingDims.size(), 1);
     SmallVector<Attribute> paddingValueAttributes;
     for (auto &operand : op->getOpOperands()) {
       auto elemType = getElementTypeOrSelf(operand.get().getType());
@@ -80,18 +74,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
 
     // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
     // we can kick canonicalization patterns to fold outer tensor.pad ops away.
-    bool nofold = false;
+    bool noFold = false;
     utils::IteratorType targetIterType = utils::IteratorType::parallel;
     switch (targetDimensions) {
     case LLVMGPUMatmulPadOption::ParallelDims:
      LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
      targetIterType = utils::IteratorType::parallel;
-      nofold = false;
+      noFold = false;
      break;
    case LLVMGPUMatmulPadOption::ReductionDims:
      LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
      targetIterType = utils::IteratorType::reduction;
-      nofold = true;
+      noFold = true;
      break;
    default: // Unreachable.
      assert(false);
@@ -106,8 +100,47 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
     });
 
     IRRewriter rewriter(ctx);
-    for (auto op : candidates) {
-      padWithZeroValue(rewriter, op, targetIterType, nofold);
+    for (linalg::LinalgOp op : candidates) {
+      SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
+      auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
+          getLoweringConfig(op));
+      if (config) {
+        switch (targetDimensions) {
+        case LLVMGPUMatmulPadOption::ParallelDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+          break;
+        case LLVMGPUMatmulPadOption::ReductionDims:
+          padMultiples = config.getStaticTilingLevelSizes(
+              static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
+          break;
+        default:
+          assert(false && "Unexpected target dimensions");
+          break;
+        }
+      }
+
+      // Populate padding dimensions.
+      SmallVector<int64_t> paddingDimensions;
+      for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
+        if (iter == targetIterType) {
+          paddingDimensions.push_back(idx);
+        }
+      }
+
+      // Populate tile sizes. We pad to multiples of workgroup/reduction
+      // tile sizes based on the selected target tiling dimensions.
+      // This pass is ran after the select target tiling is done to pad
+      // all dimensions to the select tile sizes.
+      SmallVector<int64_t> padToMultipleOf;
+      for (int64_t dim : paddingDimensions) {
+        if (padMultiples[dim] != 0) {
+          padToMultipleOf.push_back(padMultiples[dim]);
+        }
+      }
+
+      padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
+                       noFold);
     }
 
     {

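As a rough illustration of how the new per-candidate loop above pairs iterator types with lowering-config tile sizes, here is a hedged standalone sketch in plain C++. The enum, vectors, and values are illustrative stand-ins for the MLIR-side queries; the tile sizes come from the reduction level {reduction = [0, 0, 0, 8]} of the gfx940 test below.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class IterType { Parallel, Reduction };

int main() {
  // Iterator types of a batch_matmul: (batch, M, N) parallel, K reduction.
  std::vector<IterType> iters = {IterType::Parallel, IterType::Parallel,
                                 IterType::Parallel, IterType::Reduction};
  // Static tile sizes for the selected tiling level, e.g. reduction = [0, 0, 0, 8].
  std::vector<int64_t> padMultiples = {0, 0, 0, 8};
  IterType target = IterType::Reduction;

  std::vector<int64_t> paddingDimensions, padToMultipleOf;
  for (size_t idx = 0; idx < iters.size(); ++idx) {
    if (iters[idx] != target)
      continue;
    paddingDimensions.push_back(static_cast<int64_t>(idx));
    // A zero tile size means the dimension is not tiled at this level, so it
    // is skipped when collecting pad multiples.
    if (padMultiples[idx] != 0)
      padToMultipleOf.push_back(padMultiples[idx]);
  }

  for (int64_t dim : paddingDimensions)
    std::cout << "padding dim " << dim << "\n";      // padding dim 3 (the K loop)
  for (int64_t m : padToMultipleOf)
    std::cout << "pad to multiple of " << m << "\n"; // pad to multiple of 8
  return 0;
}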
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir

Lines changed: 61 additions & 0 deletions
@@ -484,6 +484,67 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
 
 // -----
 
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 32, 0], reduction = [0, 0, 0, 8]}>
+#translation = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>, subgroup_m_count = 1, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+hal.executable public @pad_batch_matmul {
+  hal.executable.variant public @rocm_hsaco_fb target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @pad_batch_matmul ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @pad_batch_matmul() attributes {translation_info = #translation} {
+        %cst = arith.constant 0.000000e+00 : f32
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>>
+        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
+        %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x16x24xf32>> -> tensor<196x16x24xf32>
+        %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [196, 24, 24], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<196x24x24xf32>> -> tensor<196x24x24xf32>
+        %5 = tensor.empty() : tensor<196x16x24xf32>
+        %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
+        %7 = linalg.batch_matmul {lowering_config = #config} ins(%3, %4 : tensor<196x16x24xf32>, tensor<196x24x24xf32>) outs(%6 : tensor<196x16x24xf32>) -> tensor<196x16x24xf32>
+        flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [196, 16, 24], strides = [1, 1, 1] : tensor<196x16x24xf32> -> !flow.dispatch.tensor<writeonly:tensor<196x16x24xf32>>
+        return
+      }
+    }
+  }
+}
+
+// This test checks if we can handle an unaligned batch matmul which has sizes
+// smaller than the chosen tile sizes. We just want to make sure we can compile
+// this example. We also check if the correct transfer_read/transfer_write are
+// produced with in_bounds attrs for the padded dimensions.
+
+// CHECK-LABEL: @pad_batch_matmul
+// CHECK:       scf.for
+// LHS
+// CHECK:         vector.transfer_read
+// CHECK-SAME:      in_bounds = [true, true, true]
+// CHECK-SAME:      memref<196x16x24xf32
+// CHECK-SAME:      vector<1x1x1xf32>
+// RHS
+// CHECK:         vector.transfer_read
+// CHECK-SAME:      in_bounds = [true, true, false]
+// CHECK-SAME:      memref<1x8x24xf32
+// CHECK-SAME:      vector<1x1x2xf32>
+// CHECK:       scf.yield
+// OUTPUT
+// CHECK:       vector.transfer_write
+// CHECK-SAME:    in_bounds = [true, true, false]
+// CHECK-SAME:    vector<1x4x1xf32>
+// CHECK-SAME:    memref<1x16x24xf32
+
+// -----
+
 // This test ensures that we are generating contraction schedules does not only work on contraction,
 // but also will be compatible with transfer_read layouts anchors.
 // Currently the transfer_read layout anchors expects WorkgroupSize % (WgTileSize / numelPerThread) == 0.

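The in_bounds expectations in the test above come down to a simple question: can any start offset produced by the (padded) iteration space, plus the vector length, step past the source size along that dimension? A rough standalone model follows; the concrete offsets are assumptions for illustration, not values taken from the compiler output.

#include <cstdint>
#include <iostream>

// A transfer along one dimension may be marked in_bounds only when every
// start offset it can be issued at keeps offset + vecLen within dimSize.
bool alwaysInBounds(int64_t dimSize, int64_t vecLen, int64_t maxStartOffset) {
  return maxStartOffset + vecLen <= dimSize;
}

int main() {
  // LHS-style read: single elements from a dimension whose offsets never
  // leave the tensor -> in_bounds = true.
  std::cout << alwaysInBounds(24, 1, 23) << "\n"; // 1
  // RHS-style read: N = 24 but the workgroup tile is 32, so the padded
  // iteration space can produce start offsets past the real extent
  // -> in_bounds = false (max offset here is assumed for illustration).
  std::cout << alwaysInBounds(24, 2, 30) << "\n"; // 0
  return 0;
}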
compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir

Lines changed: 6 additions & 5 deletions
@@ -12,6 +12,7 @@
 #map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
 #map4 = affine_map<()[s0] -> (-s0 + 64)>
 #map5 = affine_map<()[s0] -> (-s0 + 128)>
+#config = #iree_gpu.lowering_config<{workgroup = [1, 16, 16, 0], reduction = [0, 0, 0, 16]}>
 func.func @batch_matmul_f16() {
   %cst = arith.constant 0.000000e+00 : f16
   %c0 = arith.constant 0 : index
@@ -29,7 +30,7 @@ func.func @batch_matmul_f16() {
   %8 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
   %9 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
   %10 = linalg.fill ins(%cst : f16) outs(%7 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
-  %11 = linalg.batch_matmul ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
+  %11 = linalg.batch_matmul {lowering_config = #config} ins(%8, %9 : tensor<1x?x1281xf16>, tensor<1x1281x?xf16>) outs(%10 : tensor<1x?x?xf16>) -> tensor<1x?x?xf16>
   flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
   return
 }
@@ -48,14 +49,14 @@ func.func @batch_matmul_f16()
 // PARALLEL-SAME:    ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
 // PARALLEL-SAME:    outs(%[[FILL]]
 
-// The reduction dim is not tiled in the test case, so it pads it to the same
-// shape.
+// The reduction dim is not tiled in the test case, so it pads it to the
+// matmul intrinsic k.
 // REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
 // REDUCTION:     %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
 // REDUCTION:     %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// REDUCTION:     } : tensor<1x?x1281xf16> to tensor<1x?x1281xf16>
+// REDUCTION:     } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16>
 // REDUCTION:     %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// REDUCTION:     } : tensor<1x1281x?xf16> to tensor<1x1281x?xf16>
+// REDUCTION:     } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16>
 // REDUCTION:     %[[GEMM:.+]] = linalg.batch_matmul
 // REDUCTION-SAME:   ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
 // REDUCTION-SAME:   outs(%[[FILL]]
