@@ -1641,28 +1641,23 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
16411641 is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
16421642 : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
16431643
1644- #define GET_CP_ASYNC_BULK_TENSOR_ID (iid, op, dims, is_im2col ) \
1645- switch (dims) { \
1646- case 1 : \
1647- iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1648- break ; \
1649- case 2 : \
1650- iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1651- break ; \
1652- case 3 : \
1653- iid = CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1654- break ; \
1655- case 4 : \
1656- iid = CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1657- break ; \
1658- case 5 : \
1659- iid = CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1660- break ; \
1661- default : \
1662- llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1663- break ; \
1664- } \
1665- break ;
1644+ #define GET_CP_ASYNC_BULK_TENSOR_ID (op, dims, is_im2col ) \
1645+ [&]() -> auto { \
1646+ switch (dims) { \
1647+ case 1 : \
1648+ return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1649+ case 2 : \
1650+ return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1651+ case 3 : \
1652+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1653+ case 4 : \
1654+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1655+ case 5 : \
1656+ return CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1657+ default : \
1658+ llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1659+ } \
1660+ }()
16661661
16671662NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs (
16681663 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
@@ -1677,41 +1672,49 @@ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16771672 args.push_back (mt.lookupValue (thisOp.getSrcMem ()));
16781673 args.push_back (mt.lookupValue (thisOp.getTmaDescriptor ()));
16791674
1680- for (auto v : thisOp.getCoordinates ())
1675+ for (Value v : thisOp.getCoordinates ())
16811676 args.push_back (mt.lookupValue (v));
16821677
16831678 mlir::Value cacheHint = thisOp.getL2CacheHint ();
16841679 const bool hasCacheHint = static_cast <bool >(cacheHint);
1685- llvm::Value *i64Unused =
1680+ llvm::Value *i64ZeroValue =
16861681 llvm::ConstantInt::get (llvm::Type::getInt64Ty (ctx), 0 );
1687- args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64Unused );
1682+ args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64ZeroValue );
16881683 args.push_back (builder.getInt1 (hasCacheHint));
16891684
1690- llvm::Intrinsic::ID iid ;
1685+ llvm::Intrinsic::ID intrinsicID ;
16911686 int tensorDims = thisOp.getCoordinates ().size ();
16921687 bool isIm2Col = thisOp.getMode () == NVVM::TMAStoreMode::IM2COL;
16931688
16941689 using RedTy = NVVM::TMAReduxKind;
16951690 switch (thisOp.getRedKind ()) {
16961691 case RedTy::ADD:
1697- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_add, tensorDims, isIm2Col);
1692+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_add, tensorDims, isIm2Col);
1693+ break ;
16981694 case RedTy::MIN:
1699- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_min, tensorDims, isIm2Col);
1695+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_min, tensorDims, isIm2Col);
1696+ break ;
17001697 case RedTy::MAX:
1701- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_max, tensorDims, isIm2Col);
1698+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_max, tensorDims, isIm2Col);
1699+ break ;
17021700 case RedTy::INC:
1703- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_inc, tensorDims, isIm2Col);
1701+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_inc, tensorDims, isIm2Col);
1702+ break ;
17041703 case RedTy::DEC:
1705- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_dec, tensorDims, isIm2Col);
1704+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_dec, tensorDims, isIm2Col);
1705+ break ;
17061706 case RedTy::AND:
1707- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_and, tensorDims, isIm2Col);
1707+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_and, tensorDims, isIm2Col);
1708+ break ;
17081709 case RedTy::OR:
1709- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_or, tensorDims, isIm2Col);
1710+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_or, tensorDims, isIm2Col);
1711+ break ;
17101712 case RedTy::XOR:
1711- GET_CP_ASYNC_BULK_TENSOR_ID (iid, reduce_xor, tensorDims, isIm2Col);
1713+ intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_xor, tensorDims, isIm2Col);
1714+ break ;
17121715 }
17131716
1714- return {iid , std::move (args)};
1717+ return {intrinsicID , std::move (args)};
17151718}
17161719
17171720#define _none
0 commit comments