Skip to content

Commit 6ad4d3c

Browse files
committed
address comments
1 parent 0593fd9 commit 6ad4d3c

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 38 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1641,28 +1641,23 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
16411641
is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
16421642
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
16431643

1644-
#define GET_CP_ASYNC_BULK_TENSOR_ID(iid, op, dims, is_im2col) \
1645-
switch (dims) { \
1646-
case 1: \
1647-
iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1648-
break; \
1649-
case 2: \
1650-
iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1651-
break; \
1652-
case 3: \
1653-
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1654-
break; \
1655-
case 4: \
1656-
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1657-
break; \
1658-
case 5: \
1659-
iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1660-
break; \
1661-
default: \
1662-
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1663-
break; \
1664-
} \
1665-
break;
1644+
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1645+
[&]() -> auto{ \
1646+
switch (dims) { \
1647+
case 1: \
1648+
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1649+
case 2: \
1650+
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1651+
case 3: \
1652+
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1653+
case 4: \
1654+
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1655+
case 5: \
1656+
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1657+
default: \
1658+
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1659+
} \
1660+
}()
16661661

16671662
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16681663
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
@@ -1677,41 +1672,49 @@ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16771672
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
16781673
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
16791674

1680-
for (auto v : thisOp.getCoordinates())
1675+
for (Value v : thisOp.getCoordinates())
16811676
args.push_back(mt.lookupValue(v));
16821677

16831678
mlir::Value cacheHint = thisOp.getL2CacheHint();
16841679
const bool hasCacheHint = static_cast<bool>(cacheHint);
1685-
llvm::Value *i64Unused =
1680+
llvm::Value *i64ZeroValue =
16861681
llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
1687-
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1682+
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
16881683
args.push_back(builder.getInt1(hasCacheHint));
16891684

1690-
llvm::Intrinsic::ID iid;
1685+
llvm::Intrinsic::ID intrinsicID;
16911686
int tensorDims = thisOp.getCoordinates().size();
16921687
bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL;
16931688

16941689
using RedTy = NVVM::TMAReduxKind;
16951690
switch (thisOp.getRedKind()) {
16961691
case RedTy::ADD:
1697-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_add, tensorDims, isIm2Col);
1692+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1693+
break;
16981694
case RedTy::MIN:
1699-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_min, tensorDims, isIm2Col);
1695+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1696+
break;
17001697
case RedTy::MAX:
1701-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_max, tensorDims, isIm2Col);
1698+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1699+
break;
17021700
case RedTy::INC:
1703-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_inc, tensorDims, isIm2Col);
1701+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1702+
break;
17041703
case RedTy::DEC:
1705-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_dec, tensorDims, isIm2Col);
1704+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1705+
break;
17061706
case RedTy::AND:
1707-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_and, tensorDims, isIm2Col);
1707+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1708+
break;
17081709
case RedTy::OR:
1709-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_or, tensorDims, isIm2Col);
1710+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1711+
break;
17101712
case RedTy::XOR:
1711-
GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_xor, tensorDims, isIm2Col);
1713+
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1714+
break;
17121715
}
17131716

1714-
return {iid, std::move(args)};
1717+
return {intrinsicID, std::move(args)};
17151718
}
17161719

17171720
#define _none

0 commit comments

Comments
 (0)