Skip to content

Commit ca8419d

Browse files
authored
[mlir][amdgpu] Fuse adjacent MemoryCounterWaitOp (#171148)
Taking the minimum value.
1 parent ebdb903 commit ca8419d

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,8 @@ def AMDGPU_MemoryCounterWaitOp :
921921
let assemblyFormat = [{
922922
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
923923
}];
924+
925+
let hasCanonicalizer = 1;
924926
}
925927

926928
def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,51 @@ LogicalResult PermlaneSwapOp::verify() {
596596
return success();
597597
}
598598

599+
//===----------------------------------------------------------------------===//
600+
// MemoryCounterWaitOp
601+
//===----------------------------------------------------------------------===//
602+
603+
namespace {
604+
/// Fuse adjacent memory counter wait ops, taking the minimum value of the
605+
/// counters.
606+
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
607+
using Base::Base;
608+
609+
LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
610+
PatternRewriter &rewriter) const override {
611+
auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());
612+
if (!next)
613+
return failure();
614+
615+
auto setters = {&MemoryCounterWaitOp::setLoad,
616+
&MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
617+
&MemoryCounterWaitOp::setExp};
618+
auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp()};
619+
auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
620+
next.getExp()};
621+
rewriter.modifyOpInPlace(op, [&] {
622+
for (auto [setter, lhs, rhs] :
623+
llvm::zip_equal(setters, lhsVals, rhsVals)) {
624+
if (lhs && rhs) {
625+
(op.*setter)(std::min(*lhs, *rhs));
626+
} else if (lhs) {
627+
(op.*setter)(*lhs);
628+
} else if (rhs) {
629+
(op.*setter)(*rhs);
630+
}
631+
}
632+
});
633+
rewriter.eraseOp(next);
634+
return success();
635+
}
636+
};
637+
} // namespace
638+
639+
void MemoryCounterWaitOp::getCanonicalizationPatterns(
640+
RewritePatternSet &results, MLIRContext *context) {
641+
results.add<FuseMemoryCounterWaitOp>(context);
642+
}
643+
599644
//===----------------------------------------------------------------------===//
600645
// GatherToLDSOp
601646
//===----------------------------------------------------------------------===//

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,39 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
244244
%res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
245245
return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
246246
}
247+
248+
// -----
249+
250+
// CHECK-LABEL fuse_memory_counter_wait
251+
func.func @fuse_memory_counter_wait() {
252+
// CHECK: amdgpu.memory_counter_wait
253+
// CHECK-SAME: load(1) store(2) ds(2) exp(1)
254+
// CHECK-NEXT: return
255+
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
256+
amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
257+
return
258+
}
259+
260+
// CHECK-LABEL fuse_memory_counter_wait_different_counters
261+
func.func @fuse_memory_counter_wait_different_counters() {
262+
// CHECK: amdgpu.memory_counter_wait
263+
// CHECK-SAME: load(1) store(2) ds(3) exp(4)
264+
// CHECK-NEXT: return
265+
amdgpu.memory_counter_wait load(1) store(2)
266+
amdgpu.memory_counter_wait ds(3) exp(4)
267+
return
268+
}
269+
270+
func.func private @use()
271+
272+
// CHECK-LABEL fuse_memory_counter_wait_not_adjacent
273+
func.func @fuse_memory_counter_wait_not_adjacent() {
274+
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
275+
// CHECK-NEXT: call @use()
276+
// CHECK-NEXT: amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
277+
// CHECK-NEXT: return
278+
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
279+
func.call @use() : () -> ()
280+
amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
281+
return
282+
}

0 commit comments

Comments
 (0)