Skip to content

Commit a6a0b48

Browse files
committed
[mlir][NVVM] Add support for barrier0 operation with predicate
1 parent 873b8d5 commit a6a0b48

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,55 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
977977
}];
978978
}
979979

980+
// Reduction kinds understood by the predicated barrier0 operation. Each case
// ties a C++ enumerator name to its integer value and its assembly keyword.
def Barrier0PredPopc : I32EnumAttrCase<"POPC", 0, "popc">;
def Barrier0PredAnd  : I32EnumAttrCase<"AND",  1, "and">;
def Barrier0PredOr   : I32EnumAttrCase<"OR",   2, "or">;

// 32-bit enum grouping the three cases above. The C++ enum is placed in
// ::mlir::NVVM; no specialized attribute class is generated from this def
// (genSpecializedAttr = 0), the attribute wrapper is declared separately.
def Barrier0Pred : I32EnumAttr<
    "Barrier0Pred", "NVVM barrier0 predicate",
    [Barrier0PredPopc, Barrier0PredAnd, Barrier0PredOr]> {
  let cppNamespace = "::mlir::NVVM";
  let genSpecializedAttr = 0;
}
991+
// NVVM dialect attribute carrying a Barrier0Pred value under the
// `barrier0_pred` mnemonic. The value is written between angle brackets,
// e.g. `<and>`.
def Barrier0PredAttr
    : EnumAttr<NVVM_Dialect, Barrier0Pred, "barrier0_pred"> {
  let assemblyFormat = "`<` $value `>`";
}
994+
995+
// Barrier on barrier ID 0 that also reduces a per-thread i32 predicate.
//
// `$pred` and `$value` are declared as mandatory arguments, so the assembly
// format spells both unconditionally: ODS only allows optional groups to be
// anchored on optional/variadic operands or optional/default-valued
// attributes. The LLVM builder binds the intrinsic call to `$res`; without
// that assignment the op result has no LLVM value and its users cannot be
// translated.
def NVVM_Barrier0PredOp : NVVM_Op<"barrier0.pred">,
  Arguments<(ins Barrier0PredAttr:$pred, I32:$value)>,
  Results<(outs I32:$res)> {
  let summary = "CTA Barrier Synchronization with predicate (Barrier ID 0)";
  let description = [{
    The `nvvm.barrier0.pred` operation performs barrier synchronization and
    communication within a CTA (Cooperative Thread Array) using barrier ID 0,
    like `nvvm.barrier0`, and additionally reduces the predicate `$value`
    across all threads of the block. The `$pred` attribute selects the
    reduction that produces the result:

    `popc` is identical to `nvvm.barrier0` with the additional feature that it
    evaluates predicate for all threads of the block and returns the number of
    threads for which predicate evaluates to non-zero.

    `and` is identical to `nvvm.barrier0` with the additional feature that it
    evaluates predicate for all threads of the block and returns non-zero if
    and only if predicate evaluates to non-zero for all of them.

    `or` is identical to `nvvm.barrier0` with the additional feature that it
    evaluates predicate for all threads of the block and returns non-zero if and
    only if predicate evaluates to non-zero for any of them.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
  }];

  // Same token order as the fully spelled form of the previous format:
  //   nvvm.barrier0.pred %value : i32 <pred> -> i32
  let assemblyFormat =
    "$value `:` type($value) $pred attr-dict `->` type($res)";
  string llvmBuilder = [{
    $res = createIntrinsicCall(
        builder, getBarrier0IntrinsicID($pred), {$value});
  }];
}
1028+
9801029
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
9811030
let summary = "CTA Barrier Synchronization Op";
9821031
let description = [{

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,20 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
291291
llvm_unreachable("Unsupported proxy kinds");
292292
}
293293

294+
/// Picks the LLVM intrinsic used to lower a barrier on barrier ID 0.
///
/// \param pred optional predicate reduction kind.
/// \returns the aligned CTA-wide sync intrinsic when \p pred is empty,
///          otherwise the `nvvm_barrier0_*` intrinsic matching the kind.
static unsigned getBarrier0IntrinsicID(std::optional<NVVM::Barrier0Pred> pred) {
  if (pred) {
    switch (*pred) {
    case NVVM::Barrier0Pred::POPC:
      return llvm::Intrinsic::nvvm_barrier0_popc;
    case NVVM::Barrier0Pred::AND:
      return llvm::Intrinsic::nvvm_barrier0_and;
    case NVVM::Barrier0Pred::OR:
      return llvm::Intrinsic::nvvm_barrier0_or;
    }
    // Every enumerator returns above; reaching here means an invalid value.
    llvm_unreachable("Unknown predicate for barrier0");
  }
  // No predicate requested: plain barrier.
  return llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
}
307+
294308
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
295309
switch (scope) {
296310
case NVVM::MemScopeKind::CTA:

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
166166
llvm.return %1 : f32
167167
}
168168

169-
// NOTE(review): these lines are the removed side of this diff. They are the
// only check shown here that `nvvm.barrier0` translates to
// `llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)`, and this file gains no
// lines in the change - confirm that coverage for `nvvm.barrier0` (and for
// the new predicated op) exists elsewhere before dropping this test.
// CHECK-LABEL: @llvm_nvvm_barrier0
170-
llvm.func @llvm_nvvm_barrier0() {
171-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
172-
nvvm.barrier0
173-
llvm.return
174-
}
175-
176169
// CHECK-LABEL: @llvm_nvvm_barrier(
177170
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
178171
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {

0 commit comments

Comments
 (0)