@@ -1634,31 +1634,6 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
16341634 return {id, std::move (args)};
16351635}
16361636
1637- #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, dim, mode ) \
1638- llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1639-
1640- #define CP_ASYNC_BULK_TENSOR_REDUCE (op, dim, is_im2col ) \
1641- is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1642- : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1643-
1644- #define GET_CP_ASYNC_BULK_TENSOR_ID (op, dims, is_im2col ) \
1645- [&]() -> auto { \
1646- switch (dims) { \
1647- case 1 : \
1648- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1649- case 2 : \
1650- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1651- case 3 : \
1652- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1653- case 4 : \
1654- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1655- case 5 : \
1656- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1657- default : \
1658- llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1659- } \
1660- }()
1661-
16621637NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs (
16631638 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
16641639 auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
@@ -1682,37 +1657,99 @@ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16821657 args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64ZeroValue);
16831658 args.push_back (builder.getInt1 (hasCacheHint));
16841659
1685- llvm::Intrinsic::ID intrinsicID;
1686- int tensorDims = thisOp.getCoordinates ().size ();
1687- bool isIm2Col = thisOp.getMode () == NVVM::TMAStoreMode::IM2COL;
1660+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
1661+ static constexpr llvm::Intrinsic::ID IDTable[][2 ][6 ] = {
1662+ // RedTy::ADD
1663+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1664+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1665+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1666+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1667+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d},
1668+ {NI, NI, NI,
1669+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1670+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1671+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}},
1672+ // RedTy::MIN
1673+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1674+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1675+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1676+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1677+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d},
1678+ {NI, NI, NI,
1679+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1680+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1681+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}},
1682+ // RedTy::MAX
1683+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1684+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1685+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1686+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1687+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d},
1688+ {NI, NI, NI,
1689+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1690+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1691+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}},
1692+ // RedTy::INC
1693+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1694+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1695+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1696+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1697+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d},
1698+ {NI, NI, NI,
1699+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1700+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1701+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}},
1702+ // RedTy::DEC
1703+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1704+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1705+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1706+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1707+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d},
1708+ {NI, NI, NI,
1709+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1710+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1711+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}},
1712+ // RedTy::AND
1713+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1714+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1715+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1716+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1717+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d},
1718+ {NI, NI, NI,
1719+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1720+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1721+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}},
1722+ // RedTy::OR
1723+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1724+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1725+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1726+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1727+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d},
1728+ {NI, NI, NI,
1729+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1730+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1731+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}},
1732+ // RedTy::XOR
1733+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1734+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1735+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1736+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1737+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d},
1738+ {NI, NI, NI,
1739+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1740+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1741+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}},
1742+ };
16881743
1689- using RedTy = NVVM::TMAReduxKind;
1690- switch (thisOp.getRedKind ()) {
1691- case RedTy::ADD:
1692- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_add, tensorDims, isIm2Col);
1693- break ;
1694- case RedTy::MIN:
1695- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_min, tensorDims, isIm2Col);
1696- break ;
1697- case RedTy::MAX:
1698- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_max, tensorDims, isIm2Col);
1699- break ;
1700- case RedTy::INC:
1701- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_inc, tensorDims, isIm2Col);
1702- break ;
1703- case RedTy::DEC:
1704- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_dec, tensorDims, isIm2Col);
1705- break ;
1706- case RedTy::AND:
1707- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_and, tensorDims, isIm2Col);
1708- break ;
1709- case RedTy::OR:
1710- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_or, tensorDims, isIm2Col);
1711- break ;
1712- case RedTy::XOR:
1713- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_xor, tensorDims, isIm2Col);
1714- break ;
1715- }
1744+ static_assert (getMaxEnumValForTMAReduxKind () == std::size (IDTable) - 1 ,
1745+ " TMAReduxKinds must match number of rows in IDTable" );
1746+
1747+ size_t redKind = static_cast <size_t >(thisOp.getRedKind ());
1748+ size_t mode = static_cast <size_t >(thisOp.getMode ());
1749+ size_t dim = thisOp.getCoordinates ().size ();
1750+ llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
1751+ if (intrinsicID == llvm::Intrinsic::not_intrinsic)
1752+ llvm_unreachable (" Invalid intrinsic for CpAsyncBulkTensorReduceOp." );
17161753
17171754 return {intrinsicID, std::move (args)};
17181755}
0 commit comments