@@ -1634,31 +1634,6 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1634
1634
return {id, std::move (args)};
1635
1635
}
1636
1636
1637
- #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, dim, mode ) \
1638
- llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1639
-
1640
- #define CP_ASYNC_BULK_TENSOR_REDUCE (op, dim, is_im2col ) \
1641
- is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1642
- : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1643
-
1644
- #define GET_CP_ASYNC_BULK_TENSOR_ID (op, dims, is_im2col ) \
1645
- [&]() -> auto { \
1646
- switch (dims) { \
1647
- case 1 : \
1648
- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 1 , tile); \
1649
- case 2 : \
1650
- return CP_ASYNC_BULK_TENSOR_REDUCE_MODE (op, 2 , tile); \
1651
- case 3 : \
1652
- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 3 , is_im2col); \
1653
- case 4 : \
1654
- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 4 , is_im2col); \
1655
- case 5 : \
1656
- return CP_ASYNC_BULK_TENSOR_REDUCE (op, 5 , is_im2col); \
1657
- default : \
1658
- llvm_unreachable (" Invalid TensorDim in CpAsyncBulkTensorReduceOp." ); \
1659
- } \
1660
- }()
1661
-
1662
1637
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs (
1663
1638
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1664
1639
auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
@@ -1682,37 +1657,99 @@ NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
1682
1657
args.push_back (hasCacheHint ? mt.lookupValue (cacheHint) : i64ZeroValue);
1683
1658
args.push_back (builder.getInt1 (hasCacheHint));
1684
1659
1685
- llvm::Intrinsic::ID intrinsicID;
1686
- int tensorDims = thisOp.getCoordinates ().size ();
1687
- bool isIm2Col = thisOp.getMode () == NVVM::TMAStoreMode::IM2COL;
1660
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
1661
+ static constexpr llvm::Intrinsic::ID IDTable[][2 ][6 ] = {
1662
+ // RedTy::ADD
1663
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d,
1664
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d,
1665
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d,
1666
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_4d,
1667
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_5d},
1668
+ {NI, NI, NI,
1669
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d,
1670
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_4d,
1671
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_5d}},
1672
+ // RedTy::MIN
1673
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_1d,
1674
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_2d,
1675
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_3d,
1676
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_4d,
1677
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_tile_5d},
1678
+ {NI, NI, NI,
1679
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_3d,
1680
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_4d,
1681
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_min_im2col_5d}},
1682
+ // RedTy::MAX
1683
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_1d,
1684
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_2d,
1685
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_3d,
1686
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_4d,
1687
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_tile_5d},
1688
+ {NI, NI, NI,
1689
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_3d,
1690
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_4d,
1691
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_max_im2col_5d}},
1692
+ // RedTy::INC
1693
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_1d,
1694
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_2d,
1695
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_3d,
1696
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_4d,
1697
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_tile_5d},
1698
+ {NI, NI, NI,
1699
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_3d,
1700
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_4d,
1701
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_inc_im2col_5d}},
1702
+ // RedTy::DEC
1703
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_1d,
1704
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_2d,
1705
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_3d,
1706
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_4d,
1707
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_tile_5d},
1708
+ {NI, NI, NI,
1709
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_3d,
1710
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_4d,
1711
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_dec_im2col_5d}},
1712
+ // RedTy::AND
1713
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_1d,
1714
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_2d,
1715
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_3d,
1716
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_4d,
1717
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_tile_5d},
1718
+ {NI, NI, NI,
1719
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_3d,
1720
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_4d,
1721
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_and_im2col_5d}},
1722
+ // RedTy::OR
1723
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_1d,
1724
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_2d,
1725
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_3d,
1726
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_4d,
1727
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_tile_5d},
1728
+ {NI, NI, NI,
1729
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_3d,
1730
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_4d,
1731
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_or_im2col_5d}},
1732
+ // RedTy::XOR
1733
+ {{NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_1d,
1734
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_2d,
1735
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_3d,
1736
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_4d,
1737
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_tile_5d},
1738
+ {NI, NI, NI,
1739
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_3d,
1740
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_4d,
1741
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_xor_im2col_5d}},
1742
+ };
1688
1743
1689
- using RedTy = NVVM::TMAReduxKind;
1690
- switch (thisOp.getRedKind ()) {
1691
- case RedTy::ADD:
1692
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_add, tensorDims, isIm2Col);
1693
- break ;
1694
- case RedTy::MIN:
1695
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_min, tensorDims, isIm2Col);
1696
- break ;
1697
- case RedTy::MAX:
1698
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_max, tensorDims, isIm2Col);
1699
- break ;
1700
- case RedTy::INC:
1701
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_inc, tensorDims, isIm2Col);
1702
- break ;
1703
- case RedTy::DEC:
1704
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_dec, tensorDims, isIm2Col);
1705
- break ;
1706
- case RedTy::AND:
1707
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_and, tensorDims, isIm2Col);
1708
- break ;
1709
- case RedTy::OR:
1710
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_or, tensorDims, isIm2Col);
1711
- break ;
1712
- case RedTy::XOR:
1713
- intrinsicID = GET_CP_ASYNC_BULK_TENSOR_ID (reduce_xor, tensorDims, isIm2Col);
1714
- break ;
1715
- }
1744
+ static_assert (getMaxEnumValForTMAReduxKind () == std::size (IDTable) - 1 ,
1745
+ " TMAReduxKinds must match number of rows in IDTable" );
1746
+
1747
+ size_t redKind = static_cast <size_t >(thisOp.getRedKind ());
1748
+ size_t mode = static_cast <size_t >(thisOp.getMode ());
1749
+ size_t dim = thisOp.getCoordinates ().size ();
1750
+ llvm::Intrinsic::ID intrinsicID = IDTable[redKind][mode][dim];
1751
+ if (intrinsicID == llvm::Intrinsic::not_intrinsic)
1752
+ llvm_unreachable (" Invalid intrinsic for CpAsyncBulkTensorReduceOp." );
1716
1753
1717
1754
return {intrinsicID, std::move (args)};
1718
1755
}
0 commit comments