Skip to content

Commit b58ada0

Browse files
durga4githubgit-crd
authored andcommitted
[MLIR][NVVM] Update mbarrier Ops to use AnyTypeOf[] (3/3) (llvm#167567)
This is a follow-up of PR llvm#165558 and llvm#165993. This patch updates the remaining two Ops to use the AnyTypeOf[] construct, completing the migration for the mbarrier family of Ops. ``` mbarrier.arrive.expect_tx mbarrier.try_wait.parity ``` Signed-off-by: Durgadoss R <[email protected]>
1 parent 0bc798e commit b58ada0

File tree

5 files changed

+52
-96
lines changed

5 files changed

+52
-96
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
743743
}
744744

745745
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
746-
Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
746+
Arguments<(ins
747+
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
748+
I32:$txcount, PtxPredicate:$predicate)> {
747749
let summary = "MBarrier Arrive with Expected Transaction Count";
748750
let description = [{
749751
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -771,28 +773,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
771773
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
772774
}];
773775
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
774-
let extraClassDefinition = [{
775-
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
776-
}];
777-
}
778-
779-
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
780-
Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
781-
let summary = "Shared MBarrier Arrive with Expected Transaction Count";
782-
let description = [{
783-
This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
784-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
785-
786-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
787-
}];
788-
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
789-
let extraClassDefinition = [{
790-
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
791-
}];
792776
}
793777

794778
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
795-
Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
779+
Arguments<(ins
780+
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
781+
I32:$phase, I32:$ticks)> {
796782
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
797783
let description = [{
798784
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
@@ -845,46 +831,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
845831
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
846832
}];
847833
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
848-
let extraClassDefinition = [{
849-
std::string $cppClass::getPtx() {
850-
return std::string(
851-
"{\n\t"
852-
".reg .pred P1; \n\t"
853-
"LAB_WAIT: \n\t"
854-
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
855-
"@P1 bra.uni DONE; \n\t"
856-
"bra.uni LAB_WAIT; \n\t"
857-
"DONE: \n\t"
858-
"}"
859-
);
860-
}
861-
}];
862-
}
863-
864-
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
865-
Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
866-
let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
867-
let description = [{
868-
This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
869-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
870-
871-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
872-
}];
873-
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
874-
let extraClassDefinition = [{
875-
std::string $cppClass::getPtx() {
876-
return std::string(
877-
"{\n\t"
878-
".reg .pred P1; \n\t"
879-
"LAB_WAIT: \n\t"
880-
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
881-
"@P1 bra.uni DONE; \n\t"
882-
"bra.uni LAB_WAIT; \n\t"
883-
"DONE: \n\t"
884-
"}"
885-
);
886-
}
887-
}];
888834
}
889835

890836
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
922922
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
923923
adaptor.getMbarId(), rewriter);
924924
Value txcount = truncToI32(b, adaptor.getTxcount());
925-
926-
if (isMbarrierShared(op.getBarriers().getType())) {
927-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
928-
op, barrier, txcount, adaptor.getPredicate());
929-
return success();
930-
}
931-
932925
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
933926
op, barrier, txcount, adaptor.getPredicate());
934927
return success();
@@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
949942
Value ticks = truncToI32(b, adaptor.getTicks());
950943
Value phase =
951944
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
952-
953-
if (isMbarrierShared(op.getBarriers().getType())) {
954-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
955-
op, barrier, phase, ticks);
956-
return success();
957-
}
958-
959945
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
960946
phase, ticks);
961947
return success();

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ using namespace NVVM;
4747

4848
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
4949

50+
//===----------------------------------------------------------------------===//
51+
// Helper/Utility methods
52+
//===----------------------------------------------------------------------===//
53+
54+
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
55+
auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
56+
return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
57+
}
58+
59+
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
60+
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
61+
}
62+
5063
//===----------------------------------------------------------------------===//
5164
// Verifier methods
5265
//===----------------------------------------------------------------------===//
@@ -1741,26 +1754,37 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
17411754
//===----------------------------------------------------------------------===//
17421755

17431756
std::string NVVM::MBarrierInitOp::getPtx() {
1744-
unsigned addressSpace =
1745-
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1746-
return (addressSpace == NVVMMemorySpace::Shared)
1747-
? std::string("mbarrier.init.shared.b64 [%0], %1;")
1748-
: std::string("mbarrier.init.b64 [%0], %1;");
1757+
bool isShared = isPtrInSharedCTASpace(getAddr());
1758+
return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
1759+
: std::string("mbarrier.init.b64 [%0], %1;");
17491760
}
17501761

1751-
//===----------------------------------------------------------------------===//
1752-
// getIntrinsicID/getIntrinsicIDAndArgs methods
1753-
//===----------------------------------------------------------------------===//
1754-
1755-
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
1756-
auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
1757-
return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
1762+
std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
1763+
bool isShared = isPtrInSharedCTASpace(getAddr());
1764+
return isShared
1765+
? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
1766+
: std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
17581767
}
17591768

1760-
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
1761-
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
1769+
std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
1770+
bool isShared = isPtrInSharedCTASpace(getAddr());
1771+
llvm::StringRef space = isShared ? ".shared" : "";
1772+
1773+
return llvm::formatv("{\n\t"
1774+
".reg .pred P1; \n\t"
1775+
"LAB_WAIT: \n\t"
1776+
"mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
1777+
"@P1 bra.uni DONE; \n\t"
1778+
"bra.uni LAB_WAIT; \n\t"
1779+
"DONE: \n\t"
1780+
"}",
1781+
space);
17621782
}
17631783

1784+
//===----------------------------------------------------------------------===//
1785+
// getIntrinsicID/getIntrinsicIDAndArgs methods
1786+
//===----------------------------------------------------------------------===//
1787+
17641788
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
17651789
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
17661790
auto thisOp = cast<NVVM::MBarrierInitOp>(op);

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
603603
%txcount = arith.constant 256 : index
604604
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
605605
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
606-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
606+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
607607
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
608608
scf.yield
609609
} else {
610610
%txcount = arith.constant 0 : index
611611
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
612612
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
613-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
613+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
614614
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
615615
scf.yield
616616
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
620620
%ticks = arith.constant 10000000 : index
621621
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
622622
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
623-
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
623+
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
624624
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
625625

626626
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
649649
%txcount = arith.constant 256 : index
650650
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
651651
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
652-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
652+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
653653
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
654654

655655
%phase_c0 = arith.constant 0 : i1
656656
%ticks = arith.constant 10000000 : index
657657
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
658658
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
659-
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
659+
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
660660
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
661661

662662
func.return

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
1717
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
1818
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
1919
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
20-
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
20+
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
2121
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
22-
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
22+
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
2323
llvm.return
2424
}
2525

@@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
4444
// CHECK-SAME: DONE:
4545
// CHECK-SAME: }",
4646
// CHECK-SAME: "r,r,r"
47-
nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
47+
nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
4848
llvm.return
4949
}
5050

0 commit comments

Comments
 (0)