Commit 925ea55

[MLIR][NVVM] Fix undef in cp.async.bulk.tensor.reduce Op (#157423)
This change:

- Moves the LLVMIR lowering code of the NVVM dialect `cp.async.bulk.tensor.reduce` Op to `NVVMDialect.cpp`.
- Fixes the usage of `undef` in the lowering, since `undef` is now deprecated.
- Removes the macros in favor of a table-based intrinsic lookup.

The tests are updated accordingly.
1 parent 2b90eb8 commit 925ea55
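The `undef` fix amounts to one substitution in the operand list: when the op carries no `l2CacheHint`, the lowering previously passed `i64 undef` as the cache-hint operand and now passes `i64 0`. The trailing `i1` flag is `false` in that case, so the hint operand is a don't-care and a zero constant is a safe stand-in for the deprecated `llvm::UndefValue`. A minimal sketch of the pattern (the helper name is hypothetical, not part of the commit):

```cpp
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

// Hypothetical helper illustrating the fix: choose the cache-hint operand
// for the intrinsic call. The companion i1 flag (emitted separately as
// builder.getInt1(l2CacheHint != nullptr)) tells the intrinsic whether the
// hint is meaningful.
llvm::Value *cacheHintOperand(llvm::IRBuilderBase &builder,
                              llvm::Value *l2CacheHint /*may be null*/) {
  if (l2CacheHint)
    return l2CacheHint;
  // Before this commit: llvm::UndefValue::get(builder.getInt64Ty())
  return llvm::ConstantInt::get(builder.getInt64Ty(), 0);
}
```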

3 files changed: +209 -132 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 6 additions & 24 deletions
@@ -3207,35 +3207,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
   }];
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
-                                              NVVM::TMAReduxKind kind,
-                                              bool isIm2Col);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op,
+                          LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder);
   }];
 
   let hasVerifier = 1;
 
   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // shared_mem_ptr, tmaDesc, tensorDims
-    // cache_hint(if applicable) and flag(boolean)
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($srcMem);
-    translatedOperands.push_back($tmaDescriptor);
-
-    for (auto v : op.getCoordinates())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
-
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
-        op.getCoordinates().size(), $redKind,
-        (op.getMode() == NVVM::TMAStoreMode::IM2COL));
-    createIntrinsicCall(builder, intId, translatedOperands);
+    auto [id, args] = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 139 additions & 44 deletions
@@ -1802,53 +1802,148 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
-#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
-  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
+  llvm::LLVMContext &ctx = mt.getLLVMContext();
 
-#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
-  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
-            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
+  llvm::SmallVector<llvm::Value *> args;
 
-#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
-  [&]() -> auto {                                                              \
-    switch (dims) {                                                            \
-    case 1:                                                                    \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
-    case 2:                                                                    \
-      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
-    case 3:                                                                    \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
-    case 4:                                                                    \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
-    case 5:                                                                    \
-      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
-    default:                                                                   \
-      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
-    }                                                                          \
-  }()
+  // Arguments to the intrinsic:
+  // shared_mem_ptr, tmaDesc, tensorDims
+  // cache_hint(if applicable) and flag(boolean)
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64ZeroValue =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  const llvm::Intrinsic::ID notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
+  constexpr unsigned numRedKinds = 8; // ADD, MIN, MAX, INC, DEC, AND, OR, XOR
+  constexpr unsigned numLayouts = 2;  // TILE, IM2COL
+  constexpr unsigned maxDim = 5;      // 1D to 5D
+  using row = std::array<llvm::Intrinsic::ID, maxDim + 1>;
+  using layoutTable = std::array<row, numLayouts>;
+  using fullTable = std::array<layoutTable, numRedKinds>;
+  static constexpr fullTable IDTable{
+      {// RedTy::ADD
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}}}},
+       // RedTy::MIN
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}}}},
+       // RedTy::MAX
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}}}},
+       // RedTy::INC
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}}}},
+       // RedTy::DEC
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}}}},
+       // RedTy::AND
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}}}},
+       // RedTy::OR
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}}}},
+       // RedTy::XOR
+       {{{{notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d}},
+         {{notIntrinsic, notIntrinsic, notIntrinsic,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
+           llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
+           llvm::Intrinsic::
+               nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}}}}}};
+
+  static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
+                "TMAReduxKinds must match number of rows in IDTable");
+
+  size_t redKind = static_cast<size_t>(thisOp.getRedKind());
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+
+  assert(redKind < IDTable.size() &&
+         "Invalid redKind for CpAsyncBulkTensorReduceOp");
+  assert(mode < IDTable[redKind].size() &&
+         "Invalid mode for CpAsyncBulkTensorReduceOp");
+  assert(dim < IDTable[redKind][mode].size() &&
+         "Invalid dim for CpAsyncBulkTensorReduceOp");
+
+  llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
+
+  assert(intrinsicID != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
 
-llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
-    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
-  using RedTy = NVVM::TMAReduxKind;
-  switch (kind) {
-  case RedTy::ADD:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
-  case RedTy::MIN:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
-  case RedTy::MAX:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
-  case RedTy::INC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
-  case RedTy::DEC:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
-  case RedTy::AND:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
-  case RedTy::OR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
-  case RedTy::XOR:
-    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
-  }
-  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+  return {intrinsicID, std::move(args)};
 }
 
 #define _none
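A note on the table's shape: each row has `maxDim + 1` slots, with slot 0 held at `not_intrinsic`, so the coordinate count (1 to 5) indexes a row directly with no off-by-one adjustment; unsupported combinations (im2col below 3D) hold the same sentinel, which the final assert catches, and the `static_assert` keeps the table in sync with the `TMAReduxKind` enum. A reduced, self-contained sketch of the same lookup scheme (toy enum and integer IDs standing in for the real `TMAReduxKind` and intrinsic IDs):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Toy model of the IDTable lookup: kind x layout x dimension, with 0 as
// the "not an intrinsic" sentinel. Slot 0 of each row is the sentinel so
// the number of coordinates indexes the row directly.
enum class Kind : std::size_t { Add, Min };

constexpr std::size_t kNumLayouts = 2; // 0 = tile, 1 = im2col
constexpr std::size_t kMaxDim = 5;
using Row = std::array<int, kMaxDim + 1>;
using LayoutTable = std::array<Row, kNumLayouts>;

constexpr std::array<LayoutTable, 2> kTable{{
    // Kind::Add: tile supports 1D-5D, im2col only 3D-5D.
    {{{0, 11, 12, 13, 14, 15}, {0, 0, 0, 23, 24, 25}}},
    // Kind::Min
    {{{0, 31, 32, 33, 34, 35}, {0, 0, 0, 43, 44, 45}}},
}};

static_assert(kTable.size() == 2, "one row per Kind");

int lookup(Kind kind, bool isIm2Col, std::size_t numCoords) {
  int id = kTable[static_cast<std::size_t>(kind)][isIm2Col ? 1 : 0][numCoords];
  assert(id != 0 && "unsupported (kind, layout, dim) combination");
  return id; // e.g. lookup(Kind::Add, /*isIm2Col=*/false, 2) == 12
}
```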
