Skip to content
Merged
64 changes: 45 additions & 19 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
}];
}

// Attrs describing the reduction operations for the barrier operation.
// Each case is declared as I32EnumAttrCase<C++ symbol, integer value,
// assembly keyword>.
def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">;
def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">;

// 32-bit integer enum named `BarrierReduction`, placed in the
// `::mlir::NVVM` C++ namespace, that groups the three cases above.
def BarrierReduction
: I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation",
[BarrierReductionPopc, BarrierReductionAnd,
BarrierReductionOr]> {
// Do not generate a specialized attribute class for the enum itself; the
// attribute is provided by the `EnumAttr` wrapper defined right after it.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping `BarrierReduction` under the mnemonic
// `reduction`; with the assembly format below it is written as
// `#nvvm.reduction<and>`, `#nvvm.reduction<or>` or `#nvvm.reduction<popc>`.
def BarrierReductionAttr
: EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
let summary = "CTA Barrier Synchronization Op";
let description = [{
Expand All @@ -935,6 +952,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
- `numberOfThreads`: Specifies the number of threads participating in the barrier.
When specified, the value must be a multiple of the warp size. If not specified,
all threads in the CTA participate in the barrier.
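- `reductionOp`: Specifies the reduction operation (`popc`, `and`, `or`).
It must be specified together with `reductionPredicate` and cannot be
combined with `id`. The reduction result is returned in the optional `i32`
result of the operation.
- `reductionPredicate`: Specifies the predicate to be used with the
`reductionOp`.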

The barrier operation guarantees that when the barrier completes, prior memory
accesses requested by participating threads are performed relative to all threads
Expand All @@ -951,31 +971,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
}];

let arguments = (ins
Optional<I32>:$barrierId,
Optional<I32>:$numberOfThreads);
let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];

let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
OptionalAttr<BarrierReductionAttr>:$reductionOp,
Optional<I32>:$reductionPredicate);
string llvmBuilder = [{
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
if ($numberOfThreads)
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
{id, $numberOfThreads});
else
createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
if ($reductionOp)
$res = createIntrinsicCall(builder, id, args);
else
createIntrinsicCall(builder, id, args);
}];
let results = (outs Optional<I32>:$res);

let hasVerifier = 1;

let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
let assemblyFormat =
"(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
"($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";

let builders = [
OpBuilder<(ins), [{
return build($_builder, $_state, Value{}, Value{});
let builders = [OpBuilder<(ins), [{
return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
}]>,
OpBuilder<(ins "Value":$barrierId), [{
return build($_builder, $_state, barrierId, Value{});
}]>
];
OpBuilder<(ins "Value":$barrierId), [{
return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
}]>];
}

def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,15 @@ LogicalResult NVVM::BarrierOp::verify() {
if (getNumberOfThreads() && !getBarrierId())
return emitOpError(
"barrier id is missing, it should be set between 0 to 15");

if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
return emitOpError("reduction are only available when id is 0");

if ((getReductionOp() && !getReductionPredicate()) ||
(!getReductionOp() && getReductionPredicate()))
return emitOpError("reduction predicate and reduction operation must be "
"specified together");

return success();
}

Expand Down Expand Up @@ -1785,6 +1794,39 @@ std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

/// Selects the LLVM intrinsic and the call arguments used to translate an
/// `nvvm.barrier` operation.
///
/// \param op      The operation being translated; it is cast to
///                `NVVM::BarrierOp`.
/// \param mt      Translation state used to look up the LLVM values that
///                correspond to the operands of the operation.
/// \param builder IR builder used to materialize the default barrier id.
/// \returns The intrinsic ID paired with its argument list:
///   - `numberOfThreads` present: `nvvm_barrier_cta_sync_aligned_count` with
///     (barrier id, number of threads);
///   - otherwise, `reductionOp` present: `nvvm_barrier0_{and,or,popc}` with
///     the reduction predicate as the only argument;
///   - otherwise: `nvvm_barrier_cta_sync_aligned_all` with the barrier id.
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::BarrierOp>(op);
  // An omitted barrier id defaults to the constant 0.
  llvm::Value *barrierId = thisOp.getBarrierId()
                               ? mt.lookupValue(thisOp.getBarrierId())
                               : builder.getInt32(0);
  // Start from the `not_intrinsic` sentinel so that `id` is never read
  // uninitialized, even if the reduction attribute were to hold a value that
  // the switch below does not handle.
  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
  llvm::SmallVector<llvm::Value *> args;
  if (thisOp.getNumberOfThreads()) {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
    args.push_back(barrierId);
    args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
  } else if (thisOp.getReductionOp()) {
    // No `default` label: keeping the switch fully enumerated lets the
    // compiler flag any reduction kind added later but not handled here.
    switch (*thisOp.getReductionOp()) {
    case NVVM::BarrierReduction::AND:
      id = llvm::Intrinsic::nvvm_barrier0_and;
      break;
    case NVVM::BarrierReduction::OR:
      id = llvm::Intrinsic::nvvm_barrier0_or;
      break;
    case NVVM::BarrierReduction::POPC:
      id = llvm::Intrinsic::nvvm_barrier0_popc;
      break;
    }
    // The reduction intrinsics take only the predicate; the barrier id is
    // not passed to them.
    args.push_back(mt.lookupValue(thisOp.getReductionPredicate()));
  } else {
    id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
    args.push_back(barrierId);
  }

  return {id, std::move(args)};
}

mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/barrier.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s

// Exercises the translation of each `nvvm.barrier` form to an LLVM intrinsic
// call: plain, with a barrier id, with a thread count, and with a reduction.
// CHECK-LABEL: @llvm_nvvm_barrier(
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) {
// No operands: the barrier id is emitted as the constant 0.
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier
// Explicit barrier id only.
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
nvvm.barrier id = %barID
// Explicit barrier id and thread count select the `.count` intrinsic.
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
// Reduction forms: no barrier id is given, the predicate is the only call
// argument, and the call produces an i32 result.
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]])
%0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]])
%1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32
// CHECK: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]])
%2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32

llvm.return
}
19 changes: 0 additions & 19 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 {
llvm.return %1 : f32
}

// CHECK-LABEL: @llvm_nvvm_barrier0
llvm.func @llvm_nvvm_barrier0() {
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier0
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_barrier(
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
nvvm.barrier
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]])
nvvm.barrier id = %barID
// CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]])
nvvm.barrier id = %barID number_of_threads = %numberOfThreads
llvm.return
}

// CHECK-LABEL: @llvm_nvvm_cluster_arrive
llvm.func @llvm_nvvm_cluster_arrive() {
// CHECK: call void @llvm.nvvm.barrier.cluster.arrive()
Expand Down
Loading