Skip to content

Commit 0593fd9

Browse files
committed
[MLIR][NVVM] Fix undef in cp.async.bulk.tensor.reduce Op
This PR moves the LLVMIR lowering code of the NVVM dialect cp.async.bulk.tensor.reduce Op to `NVVMDialect.cpp` and fixes the lowering's use of `undef`, since `undef` is now deprecated in LLVM IR. The tests are updated accordingly.
1 parent 97d4c7d commit 0593fd9

File tree

3 files changed

+130
-117
lines changed

3 files changed

+130
-117
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2795,35 +2795,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
27952795
}];
27962796

27972797
let extraClassDeclaration = [{
2798-
static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
2799-
NVVM::TMAReduxKind kind,
2800-
bool isIm2Col);
2798+
static mlir::NVVM::IDArgPair
2799+
getIntrinsicIDAndArgs(Operation &op,
2800+
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder);
28012801
}];
28022802

28032803
let hasVerifier = 1;
28042804

28052805
string llvmBuilder = [{
2806-
// Arguments to the intrinsic:
2807-
// shared_mem_ptr, tmaDesc, tensorDims
2808-
// cache_hint(if applicable) and flag(boolean)
2809-
llvm::SmallVector<llvm::Value *> translatedOperands;
2810-
translatedOperands.push_back($srcMem);
2811-
translatedOperands.push_back($tmaDescriptor);
2812-
2813-
for (auto v : op.getCoordinates())
2814-
translatedOperands.push_back(moduleTranslation.lookupValue(v));
2815-
2816-
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2817-
auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
2818-
2819-
bool isCacheHint = op.getL2CacheHint() ? true : false;
2820-
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
2821-
translatedOperands.push_back(builder.getInt1(isCacheHint));
2822-
2823-
auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
2824-
op.getCoordinates().size(), $redKind,
2825-
(op.getMode() == NVVM::TMAStoreMode::IM2COL));
2826-
createIntrinsicCall(builder, intId, translatedOperands);
2806+
auto [id, args] = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
2807+
*op, moduleTranslation, builder);
2808+
createIntrinsicCall(builder, id, args);
28272809
}];
28282810
}
28292811

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1641,46 +1641,77 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
16411641
is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
16421642
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
16431643

1644-
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1645-
[&]() -> auto { \
1646-
switch (dims) { \
1647-
case 1: \
1648-
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1649-
case 2: \
1650-
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1651-
case 3: \
1652-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1653-
case 4: \
1654-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1655-
case 5: \
1656-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1657-
default: \
1658-
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1659-
} \
1660-
}()
1644+
#define GET_CP_ASYNC_BULK_TENSOR_ID(iid, op, dims, is_im2col) \
1645+
switch (dims) { \
1646+
case 1: \
1647+
iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1648+
break; \
1649+
case 2: \
1650+
iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1651+
break; \
1652+
case 3: \
1653+
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1654+
break; \
1655+
case 4: \
1656+
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1657+
break; \
1658+
case 5: \
1659+
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1660+
break; \
1661+
default: \
1662+
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1663+
break; \
1664+
} \
1665+
break;
1666+
1667+
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
1668+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1669+
auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
1670+
llvm::LLVMContext &ctx = mt.getLLVMContext();
1671+
1672+
llvm::SmallVector<llvm::Value *> args;
1673+
1674+
// Arguments to the intrinsic:
1675+
// shared_mem_ptr, tmaDesc, tensorDims
1676+
// cache_hint(if applicable) and flag(boolean)
1677+
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1678+
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
1679+
1680+
for (auto v : thisOp.getCoordinates())
1681+
args.push_back(mt.lookupValue(v));
1682+
1683+
mlir::Value cacheHint = thisOp.getL2CacheHint();
1684+
const bool hasCacheHint = static_cast<bool>(cacheHint);
1685+
llvm::Value *i64Unused =
1686+
llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
1687+
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1688+
args.push_back(builder.getInt1(hasCacheHint));
1689+
1690+
llvm::Intrinsic::ID iid;
1691+
int tensorDims = thisOp.getCoordinates().size();
1692+
bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL;
16611693

1662-
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1663-
int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
16641694
using RedTy = NVVM::TMAReduxKind;
1665-
switch (kind) {
1695+
switch (thisOp.getRedKind()) {
16661696
case RedTy::ADD:
1667-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1697+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_add, tensorDims, isIm2Col);
16681698
case RedTy::MIN:
1669-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1699+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_min, tensorDims, isIm2Col);
16701700
case RedTy::MAX:
1671-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1701+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_max, tensorDims, isIm2Col);
16721702
case RedTy::INC:
1673-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1703+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_inc, tensorDims, isIm2Col);
16741704
case RedTy::DEC:
1675-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1705+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_dec, tensorDims, isIm2Col);
16761706
case RedTy::AND:
1677-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1707+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_and, tensorDims, isIm2Col);
16781708
case RedTy::OR:
1679-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1709+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_or, tensorDims, isIm2Col);
16801710
case RedTy::XOR:
1681-
return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1711+
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_xor, tensorDims, isIm2Col);
16821712
}
1683-
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
1713+
1714+
return {iid, std::move(args)};
16841715
}
16851716

16861717
#define _none

0 commit comments

Comments
 (0)