Skip to content

Commit 02ee37d

Browse files
committed
remove macro
1 parent b715372 commit 02ee37d

File tree

1 file changed

+92
-55
lines changed

1 file changed

+92
-55
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 92 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -1634,31 +1634,6 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
16341634
return {id, std::move(args)};
16351635
}
16361636

1637-
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
1638-
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1639-
1640-
#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
1641-
is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1642-
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1643-
1644-
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1645-
[&]() -> auto { \
1646-
switch (dims) { \
1647-
case 1: \
1648-
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1649-
case 2: \
1650-
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1651-
case 3: \
1652-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1653-
case 4: \
1654-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1655-
case 5: \
1656-
return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1657-
default: \
1658-
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1659-
} \
1660-
}()
1661-
16621637
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16631638
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
16641639
auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
@@ -1682,37 +1657,99 @@ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
16821657
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
16831658
args.push_back(builder.getInt1(hasCacheHint));
16841659

1685-
llvm::Intrinsic::ID intrinsicID;
1686-
int tensorDims = thisOp.getCoordinates().size();
1687-
bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL;
1660+
// NI is a short alias for llvm::Intrinsic::not_intrinsic. It fills the table
// slots below for which no intrinsic is listed.
const unsigned NI = llvm::Intrinsic::not_intrinsic;
1661+
// Lookup table of the cp.async.bulk.tensor reduce intrinsic IDs. The lookup
// later in this function indexes it as
//   IDTable[reduction kind][mode][number of coordinates].
// Layout of each outer row (one per reduction kind):
//   - first inner row:  the *_tile_{1..5}d intrinsics; slot 0 is NI.
//   - second inner row: the *_im2col_{3..5}d intrinsics; slots 0..2 are NI.
// NOTE(review): the outer rows are assumed to follow the numeric order of
// NVVM::TMAReduxKind (ADD, MIN, MAX, INC, DEC, AND, OR, XOR) and the inner
// rows the numeric order of the store mode (tile first, im2col second). Only
// the outer row *count* is checked by the static_assert further down, not the
// ordering — confirm against the enum definitions.
static constexpr llvm::Intrinsic::ID IDTable[][2][6] = {
1662+
// RedTy::ADD
1663+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1664+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1665+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1666+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1667+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d},
1668+
{NI, NI, NI,
1669+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1670+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1671+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}},
1672+
// RedTy::MIN
1673+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1674+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1675+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1676+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1677+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d},
1678+
{NI, NI, NI,
1679+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1680+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1681+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}},
1682+
// RedTy::MAX
1683+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1684+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1685+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1686+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1687+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d},
1688+
{NI, NI, NI,
1689+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1690+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1691+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}},
1692+
// RedTy::INC
1693+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1694+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1695+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1696+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1697+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d},
1698+
{NI, NI, NI,
1699+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1700+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1701+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}},
1702+
// RedTy::DEC
1703+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1704+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1705+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1706+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1707+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d},
1708+
{NI, NI, NI,
1709+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1710+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1711+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}},
1712+
// RedTy::AND
1713+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1714+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1715+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1716+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1717+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d},
1718+
{NI, NI, NI,
1719+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1720+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1721+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}},
1722+
// RedTy::OR
1723+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1724+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1725+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1726+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1727+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d},
1728+
{NI, NI, NI,
1729+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1730+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1731+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}},
1732+
// RedTy::XOR
1733+
{{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1734+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1735+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1736+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1737+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d},
1738+
{NI, NI, NI,
1739+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1740+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1741+
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}},
1742+
};
16881743

1689-
using RedTy = NVVM::TMAReduxKind;
1690-
switch (thisOp.getRedKind()) {
1691-
case RedTy::ADD:
1692-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1693-
break;
1694-
case RedTy::MIN:
1695-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1696-
break;
1697-
case RedTy::MAX:
1698-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1699-
break;
1700-
case RedTy::INC:
1701-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1702-
break;
1703-
case RedTy::DEC:
1704-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1705-
break;
1706-
case RedTy::AND:
1707-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1708-
break;
1709-
case RedTy::OR:
1710-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1711-
break;
1712-
case RedTy::XOR:
1713-
intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1714-
break;
1715-
}
1744+
// Compile-time guard: the number of outer rows in IDTable must equal the
// largest TMAReduxKind value plus one. This checks the row count only.
static_assert(getMaxEnumValForTMAReduxKind() == std::size(IDTable) - 1,
1745+
"TMAReduxKinds must match number of rows in IDTable");
1746+
1747+
// Turn the op's reduction kind, store mode and coordinate count into the
// three table indices.
size_t redKind = static_cast<size_t>(thisOp.getRedKind());
1748+
size_t mode = static_cast<size_t>(thisOp.getMode());
1749+
size_t dim = thisOp.getCoordinates().size();
1750+
// NOTE(review): `mode` and `dim` index the table without a range check here.
// The table has 2 mode rows and 6 dim slots (0..5), so a larger value would
// read outside the array before the not_intrinsic test below can fire. The
// previous switch-based code had an explicit `default:` for the dim case.
// Presumably the op verifier restricts coordinates to 1..5 — confirm, or add
// an assert on `mode` and `dim` ahead of this lookup.
llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
1751+
// A not_intrinsic entry marks a (mode, dim) slot with no intrinsic in the
// table (slot 0 of every row, and im2col slots below 3).
if (intrinsicID == llvm::Intrinsic::not_intrinsic)
1752+
llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorReduceOp.");
17161753

17171754
return {intrinsicID, std::move(args)};
17181755
}

0 commit comments

Comments
 (0)