Skip to content

Commit 6a1a71b

Browse files
committed
Add getIntrinsicIDAndArgs function
1 parent cfd79ea commit 6a1a71b

File tree

3 files changed

+50
-30
lines changed

3 files changed

+50
-30
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -971,21 +971,22 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
971971
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
972972
}];
973973

974+
let extraClassDeclaration = [{
975+
static mlir::NVVM::IDArgPair
976+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
977+
llvm::IRBuilderBase& builder);
978+
}];
979+
974980
let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
975981
OptionalAttr<BarrierReductionAttr>:$reductionOp,
976982
Optional<I32>:$reductionOperand);
977983
string llvmBuilder = [{
978-
llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
979-
if ($numberOfThreads)
980-
createIntrinsicCall(
981-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
982-
{id, $numberOfThreads});
983-
else if ($reductionOp)
984-
createIntrinsicCall(
985-
builder, getBarrierIntrinsicID($reductionOp), {$reductionOperand});
986-
else
987-
createIntrinsicCall(
988-
builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
984+
auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
985+
*op, moduleTranslation, builder);
986+
if ($reductionOp)
987+
$res = createIntrinsicCall(builder, id, args);
988+
else
989+
createIntrinsicCall(builder, id, args);
989990
}];
990991
let results = (outs Optional<I32>:$res);
991992

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,11 +1518,11 @@ LogicalResult NVVM::BarrierOp::verify() {
15181518
return emitOpError(
15191519
"barrier id is missing, it should be set between 0 to 15");
15201520

1521-
if (getBarrierId() && (getReductionOp() || getReductionPredicate()))
1521+
if (getBarrierId() && (getReductionOp() || getReductionOperand()))
15221522
return emitOpError("reduction are only available when id is 0");
15231523

1524-
if ((getReductionOp() && !getReductionPredicate()) ||
1525-
(!getReductionOp() && getReductionPredicate()))
1524+
if ((getReductionOp() && !getReductionOperand()) ||
1525+
(!getReductionOp() && getReductionOperand()))
15261526
return emitOpError("reduction predicate and reduction operation must be "
15271527
"specified together");
15281528

@@ -1794,6 +1794,41 @@ std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
17941794
// getIntrinsicID/getIntrinsicIDAndArgs methods
17951795
//===----------------------------------------------------------------------===//
17961796

1797+
mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs(
1798+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1799+
auto thisOp = cast<NVVM::BarrierOp>(op);
1800+
llvm::Value *barrierId = thisOp.getBarrierId()
1801+
? mt.lookupValue(thisOp.getBarrierId())
1802+
: builder.getInt32(0);
1803+
llvm::Intrinsic::ID id;
1804+
llvm::SmallVector<llvm::Value *> args;
1805+
if (thisOp.getNumberOfThreads()) {
1806+
id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count;
1807+
args.push_back(barrierId);
1808+
args.push_back(mt.lookupValue(thisOp.getNumberOfThreads()));
1809+
} else if (thisOp.getReductionOp()) {
1810+
switch (*thisOp.getReductionOp()) {
1811+
case NVVM::BarrierReduction::AND:
1812+
id = llvm::Intrinsic::nvvm_barrier0_and;
1813+
break;
1814+
case NVVM::BarrierReduction::OR:
1815+
id = llvm::Intrinsic::nvvm_barrier0_or;
1816+
break;
1817+
case NVVM::BarrierReduction::POPC:
1818+
id = llvm::Intrinsic::nvvm_barrier0_popc;
1819+
break;
1820+
default:
1821+
llvm_unreachable("Unknown reduction operation for barrier");
1822+
}
1823+
args.push_back(mt.lookupValue(thisOp.getReductionOperand()));
1824+
} else {
1825+
id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all;
1826+
args.push_back(barrierId);
1827+
}
1828+
1829+
return {id, std::move(args)};
1830+
}
1831+
17971832
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
17981833
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
17991834
auto thisOp = cast<NVVM::MBarrierInitOp>(op);

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -291,22 +291,6 @@ static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
291291
llvm_unreachable("Unsupported proxy kinds");
292292
}
293293

294-
static unsigned
295-
getBarrierIntrinsicID(std::optional<NVVM::BarrierReduction> reduction) {
296-
if (reduction) {
297-
switch (*reduction) {
298-
case NVVM::BarrierReduction::AND:
299-
return llvm::Intrinsic::nvvm_barrier0_and;
300-
case NVVM::BarrierReduction::OR:
301-
return llvm::Intrinsic::nvvm_barrier0_or;
302-
case NVVM::BarrierReduction::POPC:
303-
return llvm::Intrinsic::nvvm_barrier0_popc;
304-
}
305-
}
306-
307-
llvm_unreachable("Unknown reduction operation for barrier");
308-
}
309-
310294
static unsigned getMembarIntrinsicID(NVVM::MemScopeKind scope) {
311295
switch (scope) {
312296
case NVVM::MemScopeKind::CTA:

0 commit comments

Comments
 (0)