@@ -1641,28 +1641,23 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1641
1641
is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1642
1642
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1643
1643
1644
// Maps (reduction-op, tensor-dims, mode) to the matching
// llvm.nvvm.cp.async.bulk.tensor.reduce.* intrinsic ID. Implemented as an
// immediately-invoked lambda so the expansion is a single expression that can
// be assigned or returned directly. 1D and 2D tensors always select tile
// mode; only for 3D-5D does `is_im2col` choose between im2col and tile.
// Out-of-range dims trap via llvm_unreachable.
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()
/// Returns the intrinsic ID and argument list for lowering this
/// cp.async.bulk.tensor.reduce op to the corresponding NVVM intrinsic.
/// The intrinsic is selected from the reduction kind, the tensor
/// dimensionality (number of coordinates), and the tile/im2col store mode.
NVVM::IDArgPair CpAsyncBulkTensorReduceOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  // NOTE(review): the lines between the signature and the first push_back
  // were outside the visible diff context; this prelude is reconstructed
  // from how thisOp/ctx/args are used below — confirm against the full file.
  auto thisOp = cast<NVVM::CpAsyncBulkTensorReduceOp>(op);
  llvm::LLVMContext &ctx = mt.getLLVMContext();

  llvm::SmallVector<llvm::Value *> args;

  // Intrinsic operands: source memory, TMA descriptor, then one value per
  // tensor coordinate.
  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (Value v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));

  // The intrinsic always takes a cache-hint operand plus an i1 flag telling
  // whether the hint is valid; when the op carries no hint, pass i64 0 as a
  // placeholder and set the flag to false.
  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64ZeroValue =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64ZeroValue);
  args.push_back(builder.getInt1(hasCacheHint));

  int tensorDims = thisOp.getCoordinates().size();
  bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL;

  // Return directly from each case instead of assigning to a local
  // llvm::Intrinsic::ID: this avoids any path on which an uninitialized ID
  // could be returned and keeps the switch default-free so -Wswitch flags
  // newly added TMAReduxKind enumerators.
  using RedTy = NVVM::TMAReduxKind;
  switch (thisOp.getRedKind()) {
  case RedTy::ADD:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::MIN:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::MAX:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::INC:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::DEC:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::AND:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::OR:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col),
            std::move(args)};
  case RedTy::XOR:
    return {GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col),
            std::move(args)};
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
#define _none
0 commit comments