Skip to content

Commit 53c785d

Browse files
[NFC][MLIR] Refactor NVVM_CpAsyncBulkGlobalToSharedClusterOp's lowering (#162611)
1 parent 24d5c4f commit 53c785d

File tree

2 files changed

+41 -28 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 8 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -3233,35 +3233,15 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
32333233
attr-dict `:` type($dstMem) `,` type($srcMem)
32343234
}];
32353235

3236+
let extraClassDeclaration = [{
3237+
static mlir::NVVM::IDArgPair
3238+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3239+
llvm::IRBuilderBase& builder);
3240+
}];
32363241
string llvmBuilder = [{
3237-
// Arguments to the intrinsic:
3238-
// dst, mbar, src, size
3239-
// multicast_mask, cache_hint,
3240-
// flag for multicast_mask,
3241-
// flag for cache_hint
3242-
llvm::SmallVector<llvm::Value *> translatedOperands;
3243-
translatedOperands.push_back($dstMem);
3244-
translatedOperands.push_back($mbar);
3245-
translatedOperands.push_back($srcMem);
3246-
translatedOperands.push_back($size);
3247-
3248-
// Multicast, if available
3249-
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
3250-
auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
3251-
bool isMulticast = op.getMulticastMask() ? true : false;
3252-
translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);
3253-
3254-
// Cachehint, if available
3255-
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
3256-
bool isCacheHint = op.getL2CacheHint() ? true : false;
3257-
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
3258-
3259-
// Flag arguments for multicast and cachehint
3260-
translatedOperands.push_back(builder.getInt1(isMulticast));
3261-
translatedOperands.push_back(builder.getInt1(isCacheHint));
3262-
3263-
createIntrinsicCall(builder,
3264-
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
3242+
auto [id, args] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
3243+
*op, moduleTranslation, builder);
3244+
createIntrinsicCall(builder, id, args);
32653245
}];
32663246
}
32673247

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
15931593
return {id, std::move(args)};
15941594
}
15951595

1596+
mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
1597+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1598+
auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
1599+
llvm::SmallVector<llvm::Value *> args;
1600+
1601+
// Fill the Intrinsic Args: dst, mbar, src, size.
1602+
args.push_back(mt.lookupValue(thisOp.getDstMem()));
1603+
args.push_back(mt.lookupValue(thisOp.getMbar()));
1604+
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1605+
args.push_back(mt.lookupValue(thisOp.getSize()));
1606+
1607+
// Multicast mask, if available.
1608+
mlir::Value multicastMask = thisOp.getMulticastMask();
1609+
const bool hasMulticastMask = static_cast<bool>(multicastMask);
1610+
llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
1611+
args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
1612+
1613+
// Cache hint, if available.
1614+
mlir::Value cacheHint = thisOp.getL2CacheHint();
1615+
const bool hasCacheHint = static_cast<bool>(cacheHint);
1616+
llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
1617+
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1618+
1619+
// Flag arguments for multicast and cachehint.
1620+
args.push_back(builder.getInt1(hasMulticastMask));
1621+
args.push_back(builder.getInt1(hasCacheHint));
1622+
1623+
llvm::Intrinsic::ID id =
1624+
llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
1625+
1626+
return {id, std::move(args)};
1627+
}
1628+
15961629
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
15971630
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
15981631
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);

0 commit comments

Comments
 (0)