@@ -1802,53 +1802,148 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
18021802 return {id, std::move (args)};
18031803}
18041804
1805- #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, dim, mode ) \
1806- llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1805+ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs (
1806+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1807+ auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
1808+ llvm::LLVMContext &ctx = mt.getLLVMContext ();
18071809
1808- #define CP_ASYNC_BULK_TENSOR_REDUCE (op, dim, is_im2col ) \
1809- is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1810- : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1810+ llvm::SmallVector<llvm::Value *> args;
18111811
1812- #define GET_CP_ASYNC_BULK_TENSOR_ID (op, dims, is_im2col ) \
1813- [&]() -> auto { \
1814- switch (dims) { \
1815- case 1 : \
1816- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1817- case 2 : \
1818- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1819- case 3 : \
1820- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1821- case 4 : \
1822- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1823- case 5 : \
1824- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1825- default : \
1826- llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1827- } \
1828- }()
1812+ // Arguments to the intrinsic:
1813+ // shared_mem_ptr, tmaDesc, tensorDims
1814+ // cache_hint(if applicable) and flag(boolean)
1815+ args.push_back (mt.lookupValue (thisOp.getSrcMem ()));
1816+ args.push_back (mt.lookupValue (thisOp.getTmaDescriptor ()));
1817+
1818+ for (Value v : thisOp.getCoordinates ())
1819+ args.push_back (mt.lookupValue (v));
1820+
1821+ mlir::Value cacheHint = thisOp.getL2CacheHint ();
1822+ const bool hasCacheHint = static_cast <bool >(cacheHint);
1823+ llvm::Value *i64ZeroValue =
1824+ llvm::ConstantInt::get (llvm::Type::getInt64Ty (ctx), 0 );
1825+ args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64ZeroValue);
1826+ args.push_back (builder.getInt1 (hasCacheHint));
1827+
1828+ const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
1829+
1830+ constexpr unsigned numRedKinds = 8 ; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
1831+ constexpr unsigned numLayouts = 2 ; // TILE, IM2COL
1832+ constexpr unsigned maxDim = 5 ; // 1D to 5D
1833+ using row = std::array<llvm::Intrinsic::ID, maxDim + 1 >;
1834+ using layoutTable = std::array<row, numLayouts>;
1835+ using fullTable = std::array<layoutTable, numRedKinds>;
1836+ static constexpr fullTable IDTable{
1837+ {// RedTy::ADD
1838+ {{{{notIntrinsic,
1839+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1840+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1841+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1842+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1843+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
1844+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1845+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1846+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1847+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
1848+ // RedTy::MIN
1849+ {{{{notIntrinsic,
1850+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1851+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1852+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1853+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1854+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
1855+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1856+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1857+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1858+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
1859+ // RedTy::MAX
1860+ {{{{notIntrinsic,
1861+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1862+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1863+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1864+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1865+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
1866+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1867+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1868+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1869+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
1870+ // RedTy::INC
1871+ {{{{notIntrinsic,
1872+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1873+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1874+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1875+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1876+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
1877+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1878+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1879+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1880+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
1881+ // RedTy::DEC
1882+ {{{{notIntrinsic,
1883+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1884+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1885+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1886+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1887+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
1888+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1889+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1890+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1891+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
1892+ // RedTy::AND
1893+ {{{{notIntrinsic,
1894+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1895+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1896+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1897+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1898+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
1899+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1900+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1901+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1902+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
1903+ // RedTy::OR
1904+ {{{{notIntrinsic,
1905+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1906+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1907+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1908+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1909+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
1910+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1911+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1912+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1913+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
1914+ // RedTy::XOR
1915+ {{{{notIntrinsic,
1916+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1917+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1918+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1919+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1920+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
1921+ {{notIntrinsic, notIntrinsic, notIntrinsic,
1922+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1923+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1924+ llvm::Intrinsic::
1925+ nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
1926+
1927+ static_assert (getMaxEnumValForTMAReduxKind () == std::size (IDTable) - 1 ,
1928+ " TMAReduxKinds must match number of rows in IDTable" );
1929+
1930+ size_t redKind = static_cast <size_t >(thisOp.getRedKind ());
1931+ size_t mode = static_cast <size_t >(thisOp.getMode ());
1932+ size_t dim = thisOp.getCoordinates ().size ();
1933+
1934+ assert (redKind < IDTable.size () &&
1935+ " Invalid redKind for CpAsyncBulkTensorReduceOp" );
1936+ assert (mode < IDTable[redKind].size () &&
1937+ " Invalid mode for CpAsyncBulkTensorReduceOp" );
1938+ assert (dim < IDTable[redKind][mode].size () &&
1939+ " Invalid dim for CpAsyncBulkTensorReduceOp" );
1940+
1941+ llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
1942+
1943+ assert (intrinsicID != notIntrinsic &&
1944+ " Invalid intrinsic for CpAsyncBulkTensorReduceOp." );
18291945
1830- llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID (
1831- int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1832- using RedTy = NVVM::TMAReduxKind;
1833- switch (kind) {
1834- case RedTy::ADD:
1835- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_add, tensorDims, isIm2Col);
1836- case RedTy::MIN:
1837- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_min, tensorDims, isIm2Col);
1838- case RedTy::MAX:
1839- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_max, tensorDims, isIm2Col);
1840- case RedTy::INC:
1841- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_inc, tensorDims, isIm2Col);
1842- case RedTy::DEC:
1843- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_dec, tensorDims, isIm2Col);
1844- case RedTy::AND:
1845- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_and, tensorDims, isIm2Col);
1846- case RedTy::OR:
1847- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_or, tensorDims, isIm2Col);
1848- case RedTy::XOR:
1849- return GET_CP_ASYNC_BULK_TENSOR_ID (reduce_xor, tensorDims, isIm2Col);
1850- }
1851- llvm_unreachable (" Invalid Reduction Op for CpAsyncBulkTensorReduceOp" );
1946+ return {intrinsicID, std::move (args)};
18521947}
18531948
18541949#define _none
0 commit comments