Skip to content

Commit f88d060

Browse files
authored
[mlir][amdgpu] memory_counter_wait tensor counter support (#171153)
1 parent f27fbca commit f88d060

File tree

8 files changed

+47
-16
lines changed

8 files changed

+47
-16
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,8 @@ def AMDGPU_MemoryCounterWaitOp :
906906
OptionalAttr<I32Attr>:$load,
907907
OptionalAttr<I32Attr>:$store,
908908
OptionalAttr<I32Attr>:$ds,
909-
OptionalAttr<I32Attr>:$exp
909+
OptionalAttr<I32Attr>:$exp,
910+
OptionalAttr<I32Attr>:$tensor
910911
)>
911912
{
912913
let summary = "Wait for specified hardware counters";
@@ -919,7 +920,7 @@ def AMDGPU_MemoryCounterWaitOp :
919920
counters into one.
920921
}];
921922
let assemblyFormat = [{
922-
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
923+
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` | `tensor` `(` $tensor `)` ) attr-dict
923924
}];
924925

925926
let hasCanonicalizer = 1;

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,16 @@ struct MemoryCounterWaitOpLowering
506506
if (std::optional<int> exp = adaptor.getExp())
507507
ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
508508

509+
if (std::optional<int> tensor = adaptor.getTensor())
510+
ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
511+
509512
rewriter.eraseOp(op);
510513
return success();
511514
}
512515

516+
if (adaptor.getTensor())
517+
return op.emitOpError("unsupported chipset");
518+
513519
auto getVal = [](Attribute attr) -> unsigned {
514520
if (attr)
515521
return cast<IntegerAttr>(attr).getInt();

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,12 @@ struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
614614

615615
auto setters = {&MemoryCounterWaitOp::setLoad,
616616
&MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
617-
&MemoryCounterWaitOp::setExp};
618-
auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp()};
617+
&MemoryCounterWaitOp::setExp,
618+
&MemoryCounterWaitOp::setTensor};
619+
auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
620+
op.getTensor()};
619621
auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
620-
next.getExp()};
622+
next.getExp(), next.getTensor()};
621623
rewriter.modifyOpInPlace(op, [&] {
622624
for (auto [setter, lhs, rhs] :
623625
llvm::zip_equal(setters, lhsVals, rhsVals)) {

mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
2-
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
3-
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
4-
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
2+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10
3+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11
4+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12
55

66
// CHECK-LABEL: func @memory_counter_wait
77
func.func @memory_counter_wait() {
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s
2+
3+
// CHECK-LABEL: func @memory_counter_wait_tensor
4+
func.func @memory_counter_wait_tensor() {
5+
// CHECK: rocdl.s.wait.tensorcnt 3
6+
amdgpu.memory_counter_wait tensor(3)
7+
8+
return
9+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx942
2+
// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1030
3+
// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1100
4+
5+
func.func @memory_counter_wait_tensor() {
6+
// expected-error @below{{failed to legalize operation 'amdgpu.memory_counter_wait'}}
7+
// expected-error @below{{'amdgpu.memory_counter_wait' op unsupported chipset}}
8+
amdgpu.memory_counter_wait tensor(0)
9+
10+
return
11+
}

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,10 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
250250
// CHECK-LABEL: fuse_memory_counter_wait
251251
func.func @fuse_memory_counter_wait() {
252252
// CHECK: amdgpu.memory_counter_wait
253-
// CHECK-SAME: load(1) store(2) ds(2) exp(1)
253+
// CHECK-SAME: load(1) store(2) ds(2) exp(1) tensor(0)
254254
// CHECK-NEXT: return
255-
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
256-
amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
255+
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
256+
amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) tensor(0)
257257
return
258258
}
259259

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,18 +671,20 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, %
671671

672672
// CHECK-LABEL: func @memory_counter_wait
673673
func.func @memory_counter_wait() {
674-
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
675-
// CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1)
674+
// CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
675+
// CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) tensor(0)
676676
// CHECK: amdgpu.memory_counter_wait load(1)
677677
// CHECK: amdgpu.memory_counter_wait store(2)
678678
// CHECK: amdgpu.memory_counter_wait ds(3)
679679
// CHECK: amdgpu.memory_counter_wait exp(4)
680-
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
681-
amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4)
680+
// CHECK: amdgpu.memory_counter_wait tensor(5)
681+
amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5)
682+
amdgpu.memory_counter_wait tensor(0) exp(1) store(2) ds(3) load(4)
682683
amdgpu.memory_counter_wait load(1)
683684
amdgpu.memory_counter_wait store(2)
684685
amdgpu.memory_counter_wait ds(3)
685686
amdgpu.memory_counter_wait exp(4)
687+
amdgpu.memory_counter_wait tensor(5)
686688
func.return
687689
}
688690

0 commit comments

Comments
 (0)