Skip to content

Commit 8efa918

Browse files
committed
[MLIR][NVVM] Update mbarrier Ops to use AnyTypeOf[] (3/3)
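This patch updates the remaining two Ops, `nvvm.mbarrier.arrive.expect_tx` and `nvvm.mbarrier.try_wait.parity`, to use the AnyTypeOf[] construct for their `$addr` operand, so that a single Op accepts either a generic or a shared-memory pointer. The separate `.shared` variants of both Ops are removed, completing the migration for the mbarrier family of Ops.

Signed-off-by: Durgadoss R <[email protected]>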
1 parent 321afc7 commit 8efa918

File tree

5 files changed

+52
-96
lines changed

5 files changed

+52
-96
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
745745
}
746746

747747
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
748-
Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
748+
Arguments<(ins
749+
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
750+
I32:$txcount, PtxPredicate:$predicate)> {
749751
let summary = "MBarrier Arrive with Expected Transaction Count";
750752
let description = [{
751753
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
773775
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
774776
}];
775777
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
776-
let extraClassDefinition = [{
777-
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
778-
}];
779-
}
780-
781-
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
782-
Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
783-
let summary = "Shared MBarrier Arrive with Expected Transaction Count";
784-
let description = [{
785-
This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
786-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
787-
788-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
789-
}];
790-
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
791-
let extraClassDefinition = [{
792-
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
793-
}];
794778
}
795779

796780
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
797-
Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
781+
Arguments<(ins
782+
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
783+
I32:$phase, I32:$ticks)> {
798784
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
799785
let description = [{
800786
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
@@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
847833
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
848834
}];
849835
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
850-
let extraClassDefinition = [{
851-
std::string $cppClass::getPtx() {
852-
return std::string(
853-
"{\n\t"
854-
".reg .pred P1; \n\t"
855-
"LAB_WAIT: \n\t"
856-
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
857-
"@P1 bra.uni DONE; \n\t"
858-
"bra.uni LAB_WAIT; \n\t"
859-
"DONE: \n\t"
860-
"}"
861-
);
862-
}
863-
}];
864-
}
865-
866-
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
867-
Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
868-
let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
869-
let description = [{
870-
This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
871-
should be accessed using a shared-memory pointer instead of a generic-memory pointer.
872-
873-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
874-
}];
875-
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
876-
let extraClassDefinition = [{
877-
std::string $cppClass::getPtx() {
878-
return std::string(
879-
"{\n\t"
880-
".reg .pred P1; \n\t"
881-
"LAB_WAIT: \n\t"
882-
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
883-
"@P1 bra.uni DONE; \n\t"
884-
"bra.uni LAB_WAIT; \n\t"
885-
"DONE: \n\t"
886-
"}"
887-
);
888-
}
889-
}];
890836
}
891837

892838
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
922922
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
923923
adaptor.getMbarId(), rewriter);
924924
Value txcount = truncToI32(b, adaptor.getTxcount());
925-
926-
if (isMbarrierShared(op.getBarriers().getType())) {
927-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
928-
op, barrier, txcount, adaptor.getPredicate());
929-
return success();
930-
}
931-
932925
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
933926
op, barrier, txcount, adaptor.getPredicate());
934927
return success();
@@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
949942
Value ticks = truncToI32(b, adaptor.getTicks());
950943
Value phase =
951944
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
952-
953-
if (isMbarrierShared(op.getBarriers().getType())) {
954-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
955-
op, barrier, phase, ticks);
956-
return success();
957-
}
958-
959945
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
960946
phase, ticks);
961947
return success();

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ using namespace NVVM;
4747

4848
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
4949

50+
//===----------------------------------------------------------------------===//
51+
// Helper/Utility methods
52+
//===----------------------------------------------------------------------===//
53+
54+
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
55+
auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
56+
return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
57+
}
58+
59+
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
60+
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
61+
}
62+
5063
//===----------------------------------------------------------------------===//
5164
// Verifier methods
5265
//===----------------------------------------------------------------------===//
@@ -1741,26 +1754,37 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
17411754
//===----------------------------------------------------------------------===//
17421755

17431756
std::string NVVM::MBarrierInitOp::getPtx() {
1744-
unsigned addressSpace =
1745-
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1746-
return (addressSpace == NVVMMemorySpace::Shared)
1747-
? std::string("mbarrier.init.shared.b64 [%0], %1;")
1748-
: std::string("mbarrier.init.b64 [%0], %1;");
1757+
bool isShared = isPtrInSharedCTASpace(getAddr());
1758+
return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
1759+
: std::string("mbarrier.init.b64 [%0], %1;");
17491760
}
17501761

1751-
//===----------------------------------------------------------------------===//
1752-
// getIntrinsicID/getIntrinsicIDAndArgs methods
1753-
//===----------------------------------------------------------------------===//
1754-
1755-
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
1756-
auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
1757-
return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
1762+
std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
1763+
bool isShared = isPtrInSharedCTASpace(getAddr());
1764+
return isShared
1765+
? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
1766+
: std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
17581767
}
17591768

1760-
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
1761-
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
1769+
std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
1770+
bool isShared = isPtrInSharedCTASpace(getAddr());
1771+
llvm::StringRef space = isShared ? ".shared" : "";
1772+
1773+
return llvm::formatv("{\n\t"
1774+
".reg .pred P1; \n\t"
1775+
"LAB_WAIT: \n\t"
1776+
"mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
1777+
"@P1 bra.uni DONE; \n\t"
1778+
"bra.uni LAB_WAIT; \n\t"
1779+
"DONE: \n\t"
1780+
"}",
1781+
space);
17621782
}
17631783

1784+
//===----------------------------------------------------------------------===//
1785+
// getIntrinsicID/getIntrinsicIDAndArgs methods
1786+
//===----------------------------------------------------------------------===//
1787+
17641788
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
17651789
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
17661790
auto thisOp = cast<NVVM::MBarrierInitOp>(op);

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
603603
%txcount = arith.constant 256 : index
604604
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
605605
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
606-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
606+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
607607
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
608608
scf.yield
609609
} else {
610610
%txcount = arith.constant 0 : index
611611
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
612612
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
613-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
613+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
614614
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
615615
scf.yield
616616
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
620620
%ticks = arith.constant 10000000 : index
621621
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
622622
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
623-
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
623+
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
624624
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
625625

626626
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
649649
%txcount = arith.constant 256 : index
650650
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
651651
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
652-
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
652+
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
653653
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
654654

655655
%phase_c0 = arith.constant 0 : i1
656656
%ticks = arith.constant 10000000 : index
657657
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
658658
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
659-
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
659+
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
660660
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
661661

662662
func.return

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
1717
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
1818
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
1919
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
20-
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
20+
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
2121
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
22-
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
22+
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
2323
llvm.return
2424
}
2525

@@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
4444
// CHECK-SAME: DONE:
4545
// CHECK-SAME: }",
4646
// CHECK-SAME: "r,r,r"
47-
nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
47+
nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
4848
llvm.return
4949
}
5050

0 commit comments

Comments
 (0)