Skip to content

Commit f7ce92f

Browse files
committed
Merge with nvvm.barrier
1 parent 2a91932 commit f7ce92f

File tree

6 files changed

+70
-89
lines changed

6 files changed

+70
-89
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -977,54 +977,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
977977
}];
978978
}
979979

980-
// Attrs describing the predicate of barrier0 operation.
981-
def Barrier0PredPopc : I32EnumAttrCase<"POPC", 0, "popc">;
982-
def Barrier0PredAnd : I32EnumAttrCase<"AND", 1, "and">;
983-
def Barrier0PredOr : I32EnumAttrCase<"OR", 2, "or">;
984-
985-
def Barrier0Pred
986-
: I32EnumAttr<"Barrier0Pred", "NVVM barrier0 predicate",
987-
[Barrier0PredPopc, Barrier0PredAnd, Barrier0PredOr]> {
980+
// Attrs describing the reduction operations for the barrier operation.
981+
def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
982+
def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">;
983+
def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">;
984+
985+
def BarrierReduction
986+
: I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation",
987+
[BarrierReductionPopc, BarrierReductionAnd,
988+
BarrierReductionOr]> {
988989
let genSpecializedAttr = 0;
989990
let cppNamespace = "::mlir::NVVM";
990991
}
991-
def Barrier0PredAttr : EnumAttr<NVVM_Dialect, Barrier0Pred, "barrier0_pred"> {
992+
def BarrierReductionAttr
993+
: EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
992994
let assemblyFormat = "`<` $value `>`";
993995
}
994996

995-
def NVVM_Barrier0PredOp : NVVM_Op<"barrier0.pred">,
996-
Arguments<(ins Barrier0PredAttr:$pred, I32:$value)>,
997-
Results<(outs I32:$res)> {
998-
let summary = "CTA Barrier Synchronization with predicate (Barrier ID 0)";
999-
let description = [{
1000-
The `nvvm.barrier0` operation is a convenience operation that performs
1001-
barrier synchronization and communication within a CTA
1002-
(Cooperative Thread Array) using barrier ID 0. It is functionally
1003-
equivalent to `nvvm.barrier` or `nvvm.barrier id=0`.
1004-
1005-
`popc` is identical to `nvvm.barrier0` with the additional feature that it
1006-
evaluates predicate for all threads of the block and returns the number of
1007-
threads for which predicate evaluates to non-zero.
1008-
1009-
`and` is identical to `nvvm.barrier0` with the additional feature that it
1010-
evaluates predicate for all threads of the block and returns non-zero if
1011-
and only if predicate evaluates to non-zero for all of them.
1012-
1013-
`or` is identical to `nvvm.barrier0` with the additional feature that it
1014-
evaluates predicate for all threads of the block and returns non-zero if and
1015-
only if predicate evaluates to non-zero for any of them.
1016-
1017-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
1018-
}];
1019-
1020-
let assemblyFormat =
1021-
"$value `:` type($value) $pred attr-dict `->` type($res)";
1022-
string llvmBuilder = [{
1023-
createIntrinsicCall(
1024-
builder, getBarrier0IntrinsicID($pred), {$value});
1025-
}];
1026-
}
1027-
1028997
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
1029998
let summary = "CTA Barrier Synchronization Op";
1030999
let description = [{
@@ -1039,6 +1008,7 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
10391008
- `numberOfThreads`: Specifies the number of threads participating in the barrier.
10401009
When specified, the value must be a multiple of the warp size. If not specified,
10411010
all threads in the CTA participate in the barrier.
1011+
- `reductionOp`
10421012

10431013
The barrier operation guarantees that when the barrier completes, prior memory
10441014
accesses requested by participating threads are performed relative to all threads
@@ -1055,31 +1025,36 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
10551025
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
10561026
}];
10571027

1058-
let arguments = (ins
1059-
Optional<I32>:$barrierId,
1060-
Optional<I32>:$numberOfThreads);
1028+
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
1029+
OptionalAttr<BarrierReductionAttr>:$reductionOp,
1030+
Optional<I32>:$reductionPredicate);
10611031
string llvmBuilder = [{
10621032
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
10631033
if ($numberOfThreads)
10641034
createIntrinsicCall(
10651035
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
10661036
{id, $numberOfThreads});
1037+
else if ($reductionOp)
1038+
createIntrinsicCall(
1039+
builder, getBarrierIntrinsicID($reductionOp), {$reductionPredicate});
10671040
else
10681041
createIntrinsicCall(
10691042
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
10701043
}];
1044+
let results = (outs Optional<I32>:$res);
1045+
10711046
let hasVerifier = 1;
10721047

1073-
let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
1048+
let assemblyFormat =
1049+
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
1050+
"($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
10741051

1075-
let builders = [
1076-
OpBuilder<(ins), [{
1077-
return build($_builder, $_state, Value{}, Value{});
1052+
let builders = [OpBuilder<(ins), [{
1053+
return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
10781054
}]>,
1079-
OpBuilder<(ins "Value":$barrierId), [{
1080-
return build($_builder, $_state, barrierId, Value{});
1081-
}]>
1082-
];
1055+
OpBuilder<(ins "Value":$barrierId), [{
1056+
return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
1057+
}]>];
10831058
}
10841059

10851060
def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,15 @@ LogicalResult NVVM::BarrierOp::verify() {
15041504
if (getNumberOfThreads() && !getBarrierId())
15051505
return emitOpError(
15061506
"barrier id is missing, it should be set between 0 to 15");
1507+
1508+
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1509+
return emitOpError("reduction are only available for barrier id 0");
1510+
1511+
if ((getReductionOp() && !getReductionPredicate()) ||
1512+
(!getReductionOp() && getReductionPredicate()))
1513+
return emitOpError("reduction predicate and reduction operation must be "
1514+
"specified together");
1515+
15071516
return success();
15081517
}
15091518

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,16 +291,20 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
291291
llvm_unreachable("Unsupported proxy kinds");
292292
}
293293

294-
static unsigned getBarrier0IntrinsicID(NVVM::Barrier0Pred pred) {
295-
switch (pred) {
296-
case NVVM::Barrier0Pred::AND:
297-
return llvm::Intrinsic::nvvm_barrier0_and;
298-
case NVVM::Barrier0Pred::OR:
299-
return llvm::Intrinsic::nvvm_barrier0_or;
300-
case NVVM::Barrier0Pred::POPC:
301-
return llvm::Intrinsic::nvvm_barrier0_popc;
294+
static unsigned
295+
getBarrierIntrinsicID(std::optional<NVVM::BarrierReduction> reduction) {
296+
if (reduction) {
297+
switch (*reduction) {
298+
case NVVM::BarrierReduction::AND:
299+
return llvm::Intrinsic::nvvm_barrier0_and;
300+
case NVVM::BarrierReduction::OR:
301+
return llvm::Intrinsic::nvvm_barrier0_or;
302+
case NVVM::BarrierReduction::POPC:
303+
return llvm::Intrinsic::nvvm_barrier0_popc;
304+
}
302305
}
303-
llvm_unreachable("Unknown predicate for barrier0");
306+
307+
llvm_unreachable("Unknown reduction operation for barrier");
304308
}
305309

306310
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: @llvm_nvvm_barrier(
4+
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[predicate:.*]])
5+
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %predicate : i32) {
6+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
7+
nvvm.barrier
8+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
9+
nvvm.barrier id = %barID
10+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
11+
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
12+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[predicate]])
13+
%0 = nvvm.barrier #nvvm.reduction<and> %predicate -> i32
14+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[predicate]])
15+
%1 = nvvm.barrier #nvvm.reduction<or> %predicate -> i32
16+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[predicate]])
17+
%2 = nvvm.barrier #nvvm.reduction<popc> %predicate -> i32
18+
19+
llvm.return
20+
}

mlir/test/Target/LLVMIR/nvvm/barrier0.mlir

Lines changed: 0 additions & 15 deletions
This file was deleted.

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
166166
llvm.return %1 : f32
167167
}
168168

169-
// CHECK-LABEL: @llvm_nvvm_barrier(
170-
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
171-
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
172-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
173-
nvvm.barrier
174-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
175-
nvvm.barrier id = %barID
176-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
177-
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
178-
llvm.return
179-
}
180-
181169
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
182170
llvm.func @llvm_nvvm_cluster_arrive() {
183171
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()

0 commit comments

Comments
 (0)