Skip to content

Commit dd5fe88

Browse files
committed
Add getIntrinsicIDAndArgs function
1 parent 79099d0 commit dd5fe88

File tree

3 files changed

+50
-30
lines changed

3 files changed

+50
-30
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,21 +1027,22 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
10271027
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
10281028
}];
10291029

1030+
// Declares a static helper on the generated op class. It returns an
// `mlir::NVVM::IDArgPair`; the `llvmBuilder` of this op unpacks that pair
// into an intrinsic ID and an argument list and forwards both to
// `createIntrinsicCall`.
let extraClassDeclaration = [{
1031+
static mlir::NVVM::IDArgPair
1032+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1033+
llvm::IRBuilderBase& builder);
1034+
}];
1035+
10301036
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
10311037
OptionalAttr<BarrierReductionAttr>:$reductionOp,
10321038
Optional<I32>:$reductionOperand);
10331039
string llvmBuilder = [{
1034-
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
1035-
if ($numberOfThreads)
1036-
createIntrinsicCall(
1037-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
1038-
{id, $numberOfThreads});
1039-
else if ($reductionOp)
1040-
createIntrinsicCall(
1041-
builder, getBarrierIntrinsicID($reductionOp), {$reductionOperand});
1042-
else
1043-
createIntrinsicCall(
1044-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
1040+
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
1041+
*op, moduleTranslation, builder);
1042+
if ($reductionOp)
1043+
$res = createIntrinsicCall(builder, id, args);
1044+
else
1045+
createIntrinsicCall(builder, id, args);
10451046
}];
10461047
let results = (outs Optional<I32>:$res);
10471048

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,11 +1505,11 @@ LogicalResult NVVM::BarrierOp::verify() {
15051505
return emitOpError(
15061506
"barrier id is missing, it should be set between 0 to 15");
15071507

1508-
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1508+
if (getBarrierId() && (getReductionOp() || getReductionOperand()))
15091509
return emitOpError("reduction are only available when id is 0");
15101510

1511-
if ((getReductionOp() && !getReductionPredicate()) ||
1512-
(!getReductionOp() && getReductionPredicate()))
1511+
if ((getReductionOp() && !getReductionOperand()) ||
1512+
(!getReductionOp() && getReductionOperand()))
15131513
return emitOpError("reduction predicate and reduction operation must be "
15141514
"specified together");
15151515

@@ -1770,6 +1770,41 @@ static bool isPtrInSharedCTASpace(mlir::Value ptr) {
17701770
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
17711771
}
17721772

1773+
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
1774+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1775+
auto thisOp = cast<NVVM::BarrierOp>(op);
1776+
llvm::Value *barrierId = thisOp.getBarrierId()
1777+
? mt.lookupValue(thisOp.getBarrierId())
1778+
: builder.getInt32(0);
1779+
llvm::Intrinsic::ID id;
1780+
llvm::SmallVector<llvm::Value *> args;
1781+
if (thisOp.getNumberOfThreads()) {
1782+
id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
1783+
args.push_back(barrierId);
1784+
args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
1785+
} else if (thisOp.getReductionOp()) {
1786+
switch (*thisOp.getReductionOp()) {
1787+
case NVVM::BarrierReduction::AND:
1788+
id = llvm::Intrinsic::nvvm_barrier0_and;
1789+
break;
1790+
case NVVM::BarrierReduction::OR:
1791+
id = llvm::Intrinsic::nvvm_barrier0_or;
1792+
break;
1793+
case NVVM::BarrierReduction::POPC:
1794+
id = llvm::Intrinsic::nvvm_barrier0_popc;
1795+
break;
1796+
default:
1797+
llvm_unreachable("Unknown reduction operation for barrier");
1798+
}
1799+
args.push_back(mt.lookupValue(thisOp.getReductionOperand()));
1800+
} else {
1801+
id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
1802+
args.push_back(barrierId);
1803+
}
1804+
1805+
return {id, std::move(args)};
1806+
}
1807+
17731808
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
17741809
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
17751810
auto thisOp = cast<NVVM::MBarrierInitOp>(op);

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -291,22 +291,6 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
291291
llvm_unreachable("Unsupported proxy kinds");
292292
}
293293

294-
static unsigned
295-
getBarrierIntrinsicID(std::optional<NVVM::BarrierReduction> reduction) {
296-
if (reduction) {
297-
switch (*reduction) {
298-
case NVVM::BarrierReduction::AND:
299-
return llvm::Intrinsic::nvvm_barrier0_and;
300-
case NVVM::BarrierReduction::OR:
301-
return llvm::Intrinsic::nvvm_barrier0_or;
302-
case NVVM::BarrierReduction::POPC:
303-
return llvm::Intrinsic::nvvm_barrier0_popc;
304-
}
305-
}
306-
307-
llvm_unreachable("Unknown reduction operation for barrier");
308-
}
309-
310294
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
311295
switch (scope) {
312296
case NVVM::MemScopeKind::CTA:

0 commit comments

Comments
 (0)