@@ -1641,46 +1641,77 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
             : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

-#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
-  [&]() -> auto { \
-    switch (dims) { \
-    case 1: \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
-    case 2: \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
-    case 3: \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
-    case 4: \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
-    case 5: \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
-    default: \
-      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
-    } \
-  }()
+#define GET_CP_ASYNC_BULK_TENSOR_ID(iid, op, dims, is_im2col) \
+  switch (dims) { \
+  case 1: \
+    iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
+    break; \
+  case 2: \
+    iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
+    break; \
+  case 3: \
+    iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
+    break; \
+  case 4: \
+    iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
+    break; \
+  case 5: \
+    iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
+    break; \
+  default: \
+    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
+    break; \
+  } \
+  break;
+
+NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
+  llvm::LLVMContext &ctx = mt.getLLVMContext();
+
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Arguments to the intrinsic:
+  // shared_mem_ptr, tmaDesc, tensorDims,
+  // cache_hint(if applicable) and flag(boolean)
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  llvm::Intrinsic::ID iid;
+  int tensorDims = thisOp.getCoordinates().size();
+  bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL;

-llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
-    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
   using RedTy = NVVM::TMAReduxKind;
-  switch (kind) {
+  switch (thisOp.getRedKind()) {
   case RedTy::ADD:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_add, tensorDims, isIm2Col);
   case RedTy::MIN:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_min, tensorDims, isIm2Col);
   case RedTy::MAX:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_max, tensorDims, isIm2Col);
   case RedTy::INC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_inc, tensorDims, isIm2Col);
   case RedTy::DEC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_dec, tensorDims, isIm2Col);
   case RedTy::AND:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_and, tensorDims, isIm2Col);
   case RedTy::OR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_or, tensorDims, isIm2Col);
   case RedTy::XOR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
+    GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_xor, tensorDims, isIm2Col);
   }
-  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+
+  return {iid, std::move(args)};
 }

 #define _none
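
The NVVM::IDArgPair returned above packages the selected intrinsic ID together with the already-materialized LLVM argument list, so the op's LLVM IR builder only has to emit the call during NVVM-to-LLVM translation. Below is a minimal, hypothetical sketch of such a call site; it is not part of this diff, and the helper createIntrinsicCall as well as the op/moduleTranslation/builder names are assumptions about the surrounding translation hook.

// Hypothetical call site (sketch, not from this diff). Assumes MLIR's
// createIntrinsicCall(builder, id, args) helper from the LLVM translation
// utilities and `op` referring to the nvvm.cp.async.bulk.tensor.reduce
// operation being lowered.
auto [id, args] = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    *op, moduleTranslation, builder);
// The reduce intrinsic produces no result, so nothing is mapped back to MLIR.
createIntrinsicCall(builder, id, args);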