@@ -1802,53 +1802,148 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
-#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                       \
-  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
+  llvm::LLVMContext &ctx = mt.getLLVMContext();
-#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                       \
-  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)               \
-            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
+
+  llvm::SmallVector<llvm::Value *> args;
-#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                      \
-  [&]() -> auto {                                                             \
-    switch (dims) {                                                           \
-    case 1:                                                                   \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                   \
-    case 2:                                                                   \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                   \
-    case 3:                                                                   \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                   \
-    case 4:                                                                   \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                   \
-    case 5:                                                                   \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                   \
-    default:                                                                  \
-      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");    \
-    }                                                                         \
-  }()
+
+  // Arguments to the intrinsic:
+  // shared_mem_ptr, tmaDesc, tensorDims,
+  // cache_hint (if applicable) and flag (boolean)
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+
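+  // When no L2 cache hint is provided, pass a zero placeholder and set the
+  // trailing i1 flag to false so the intrinsic treats the hint as absent.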
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64ZeroValue =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
+  args.push_back(builder.getInt1(hasCacheHint));
+
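+  // Sentinel for (kind, layout, dim) combinations that have no matching
+  // intrinsic in the table below.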
+  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
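+  // Constexpr 3-D lookup table for the intrinsic ID, indexed as
+  // IDTable[reduction kind][layout][tensor dim]. Index 0 of each row and the
+  // 1D/2D im2col slots hold notIntrinsic, since those variants do not exist.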
+  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
+  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
+  constexpr unsigned maxDim = 5;      // 1D to 5D
+  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
+  using layoutTable = std::array<row, numLayouts>;
+  using fullTable = std::array<layoutTable, numRedKinds>;
+  static constexpr fullTable IDTable{
+      {// RedTy::ADD
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
+       // RedTy::MIN
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
+       // RedTy::MAX
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
+       // RedTy::INC
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
+       // RedTy::DEC
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
+       // RedTy::AND
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
+       // RedTy::OR
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
+       // RedTy::XOR
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
+           llvm::Intrinsic::
+               nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
+
+  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
+                "TMAReduxKinds must match number of rows in IDTable");
+
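+  // The TMAReduxKind and mode enum values are used directly as table
+  // indices; the tensor dimensionality is the number of coordinates.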
+  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+
+  assert(redKind < IDTable.size() &&
+         "Invalid redKind for CpAsyncBulkTensorReduceOp");
+  assert(mode < IDTable[redKind].size() &&
+         "Invalid mode for CpAsyncBulkTensorReduceOp");
+  assert(dim < IDTable[redKind][mode].size() &&
+         "Invalid dim for CpAsyncBulkTensorReduceOp");
+
+  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
+
+  assert(intrinsicID != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
-llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
-    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
-  using RedTy = NVVM::TMAReduxKind;
-  switch (kind) {
-  case RedTy::ADD:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
-  case RedTy::MIN:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
-  case RedTy::MAX:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
-  case RedTy::INC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
-  case RedTy::DEC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
-  case RedTy::AND:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
-  case RedTy::OR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
-  case RedTy::XOR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
-  }
-  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+
+  return {intrinsicID, std::move(args)};
 }
 
 #define _none