Skip to content

Commit 6681e21

Browse files
clementval and grypp
authored and committed
[mlir][NVVM] Add support for barrier0-reduction operation (llvm#167036)
Add support for the `nvvm.barrier0.[popc|and|or]` operations. It is added as a separate operation since `Barrier0Op` has no result. Reference: https://docs.nvidia.com/cuda/nvvm-ir-spec/#barrier-and-memory-fence
This will be used in CUDA Fortran lowering: https://github.com/llvm/llvm-project/blob/49f55f4991227f3c7a2b8161bbf45c74b7023944/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp#L1081
It could also be used later for CUDA C/C++ with CIR: https://github.com/llvm/llvm-project/blob/49f55f4991227f3c7a2b8161bbf45c74b7023944/clang/lib/Headers/__clang_cuda_device_functions.h#L524
---------
Co-authored-by: Guray Ozen <[email protected]>
1 parent 6db7a27 commit 6681e21

File tree

4 files changed

+107
-38
lines changed

4 files changed

+107
-38
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
921921
}];
922922
}
923923

924+
// Reduction kinds that a barrier operation can apply to the predicate
// supplied by the participating threads. Each case maps a C++ enumerator
// (POPC/AND/OR) to its integer value and its assembly keyword.
def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
def BarrierReductionAnd  : I32EnumAttrCase<"AND", 1, "and">;
def BarrierReductionOr   : I32EnumAttrCase<"OR", 2, "or">;

// 32-bit enum grouping the reduction kinds above. It is emitted into the
// ::mlir::NVVM namespace; the specialized attribute class is not generated
// here (genSpecializedAttr = 0) because BarrierReductionAttr below wraps it.
def BarrierReduction : I32EnumAttr<
    "BarrierReduction", "NVVM barrier reduction operation",
    [BarrierReductionPopc, BarrierReductionAnd, BarrierReductionOr]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::NVVM";
}

// NVVM dialect attribute carrying a BarrierReduction value under the
// `reduction` mnemonic, written as e.g. `#nvvm.reduction<and>`.
def BarrierReductionAttr
    : EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
  let assemblyFormat = "`<` $value `>`";
}
940+
924941
def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
925942
let summary = "CTA Barrier Synchronization Op";
926943
let description = [{
@@ -935,6 +952,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
935952
- `numberOfThreads`: Specifies the number of threads participating in the barrier.
936953
When specified, the value must be a multiple of the warp size. If not specified,
937954
all threads in the CTA participate in the barrier.
955+
- `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`).
956+
- `reductionPredicate`: specifies the predicate to be used with the
957+
`reductionOp`.
938958

939959
The barrier operation guarantees that when the barrier completes, prior memory
940960
accesses requested by participating threads are performed relative to all threads
@@ -951,31 +971,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
951971
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
952972
}];
953973

954-
let arguments = (ins
955-
Optional<I32>:$barrierId,
956-
Optional<I32>:$numberOfThreads);
974+
let extraClassDeclaration = [{
975+
static mlir::NVVM::IDArgPair
976+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
977+
llvm::IRBuilderBase& builder);
978+
}];
979+
980+
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
981+
OptionalAttr<BarrierReductionAttr>:$reductionOp,
982+
Optional<I32>:$reductionPredicate);
957983
string llvmBuilder = [{
958-
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
959-
if ($numberOfThreads)
960-
createIntrinsicCall(
961-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
962-
{id, $numberOfThreads});
963-
else
964-
createIntrinsicCall(
965-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
984+
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
985+
*op, moduleTranslation, builder);
986+
if ($reductionOp)
987+
$res = createIntrinsicCall(builder, id, args);
988+
else
989+
createIntrinsicCall(builder, id, args);
966990
}];
991+
let results = (outs Optional<I32>:$res);
992+
967993
let hasVerifier = 1;
968994

969-
let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
995+
let assemblyFormat =
996+
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
997+
"($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
970998

971-
let builders = [
972-
OpBuilder<(ins), [{
973-
return build($_builder, $_state, Value{}, Value{});
999+
let builders = [OpBuilder<(ins), [{
1000+
return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
9741001
}]>,
975-
OpBuilder<(ins "Value":$barrierId), [{
976-
return build($_builder, $_state, barrierId, Value{});
977-
}]>
978-
];
1002+
OpBuilder<(ins "Value":$barrierId), [{
1003+
return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
1004+
}]>];
9791005
}
9801006

9811007
def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,15 @@ LogicalResult NVVM::BarrierOp::verify() {
15171517
if (getNumberOfThreads() && !getBarrierId())
15181518
return emitOpError(
15191519
"barrier id is missing, it should be set between 0 to 15");
1520+
1521+
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1522+
return emitOpError("reduction are only available when id is 0");
1523+
1524+
if ((getReductionOp() && !getReductionPredicate()) ||
1525+
(!getReductionOp() && getReductionPredicate()))
1526+
return emitOpError("reduction predicate and reduction operation must be "
1527+
"specified together");
1528+
15201529
return success();
15211530
}
15221531

@@ -1785,6 +1794,39 @@ std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
17851794
// getIntrinsicID/getIntrinsicIDAndArgs methods
17861795
//===----------------------------------------------------------------------===//
17871796

1797+
/// Selects the LLVM intrinsic, and the arguments to call it with, for the
/// translation of an `nvvm.barrier` operation.
///
/// Three forms are handled:
///  - reduction (`popc` / `and` / `or`): the matching `nvvm_barrier0_*`
///    intrinsic, called with the translated reduction predicate only;
///  - explicit thread count: `nvvm_barrier_cta_sync_aligned_count`, called
///    with the barrier id followed by the number of threads;
///  - plain barrier: `nvvm_barrier_cta_sync_aligned_all`, called with the
///    barrier id.
///
/// When the op has no barrier id operand, the constant `i32 0` is used.
///
/// \param op      Operation being translated; must be an `NVVM::BarrierOp`.
/// \param mt      Translation state used to look up translated operands.
/// \param builder IR builder used to materialize the default barrier id.
/// \returns The intrinsic ID paired with its argument list.
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::BarrierOp>(op);

  // Start from a well-defined value so an unexpected enum value can never
  // leak an indeterminate intrinsic ID to the caller.
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  llvm::SmallVector<llvm::Value *> args;

  // The reduction form is tested first: it is the only form that produces a
  // result, so it must always map to a value-returning intrinsic. The op
  // verifier rejects a reduction combined with an explicit barrier id, so for
  // verified ops this branch is mutually exclusive with the two below.
  if (auto reduction = thisOp.getReductionOp()) {
    switch (*reduction) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier0_and;
      break;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier0_or;
      break;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier0_popc;
      break;
    }
    // The barrier0 reductions take the predicate as their sole argument.
    args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
    return {id, std::move(args)};
  }

  // Non-reduction forms always pass the barrier id first; fall back to
  // barrier 0 when the operand was omitted.
  llvm::Value *barrierId = thisOp.getBarrierId()
                               ? mt.lookupValue(thisOp.getBarrierId())
                               : builder.getInt32(0);
  args.push_back(barrierId);

  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
  }

  return {id, std::move(args)};
}
1829+
17881830
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
17891831
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
17901832
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
2+
3+
// CHECK-LABEL: @llvm_nvvm_barrier(
4+
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
5+
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
6+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
7+
nvvm.barrier
8+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
9+
nvvm.barrier id = %barID
10+
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
11+
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
12+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
13+
%0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
14+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
15+
%1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
16+
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
17+
%2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32
18+
19+
llvm.return
20+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
166166
llvm.return %1 : f32
167167
}
168168

169-
// CHECK-LABEL: @llvm_nvvm_barrier0
170-
llvm.func @llvm_nvvm_barrier0() {
171-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
172-
nvvm.barrier0
173-
llvm.return
174-
}
175-
176-
// CHECK-LABEL: @llvm_nvvm_barrier(
177-
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
178-
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
179-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
180-
nvvm.barrier
181-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
182-
nvvm.barrier id = %barID
183-
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
184-
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
185-
llvm.return
186-
}
187-
188169
// CHECK-LABEL: @llvm_nvvm_cluster_arrive
189170
llvm.func @llvm_nvvm_cluster_arrive() {
190171
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()

0 commit comments

Comments
 (0)