@@ -2013,9 +2013,36 @@ class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
20132013 !foreach(i, !range(flags),
20142014 ImmArg<ArgIndex<!add(i, !size(param_types))>>))>;
20152015
2016- // Intrinsics for Tensor Copy using TMA
2017- // G2S -> From Global to Shared memory variants
2018- // S2G -> From Shared to Global memory variants
2016+ // TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
2017+ foreach dim = 1...5 in {
2018+ defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);
2019+ foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
2020+ def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
2021+ DefaultAttrsIntrinsicFlags<[],
2022+ !listconcat([llvm_shared_ptr_ty, // src_smem_ptr
2023+ llvm_ptr_ty], // tensormap_ptr
2024+ tensor_dim_args, // actual tensor dims
2025+ [llvm_i64_ty]), // cache_hint
2026+ [llvm_i1_ty], // Flag for cache_hint
2027+ [IntrConvergent,
2028+ ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
2029+ NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
2030+
2031+ // Intrinsics for TMA Copy with reduction
2032+ foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
2033+ def int_nvvm_cp_async_bulk_tensor_reduce_ # red_op # _ # mode # _ # dim # d :
2034+ DefaultAttrsIntrinsicFlags<[],
2035+ !listconcat([llvm_shared_ptr_ty, // src_smem_ptr
2036+ llvm_ptr_ty], // tensormap_ptr
2037+ tensor_dim_args, // actual tensor dims
2038+ [llvm_i64_ty]), // cache_hint
2039+ [llvm_i1_ty], // Flag for cache_hint
2040+ [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
2041+ NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
2042+ }
2043+ }
2044+
2045+ // TMA Tensor Copy Intrinsics: G2S -> From Global to Shared memory variants
20192046foreach dim = 1...5 in {
20202047 defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);
20212048
@@ -2045,17 +2072,6 @@ foreach dim = 1...5 in {
20452072 def int_nvvm_cp_async_bulk_tensor_g2s_ # mode # _ # dim # d :
20462073 DefaultAttrsIntrinsicFlags<[], g2s_params, g2s_flags, g2s_props>;
20472074
2048- def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
2049- DefaultAttrsIntrinsicFlags<[],
2050- !listconcat([llvm_shared_ptr_ty, // src_smem_ptr
2051- llvm_ptr_ty], // tensormap_ptr
2052- tensor_dim_args, // actual tensor dims
2053- [llvm_i64_ty]), // cache_hint
2054- [llvm_i1_ty], // Flag for cache_hint
2055- [IntrConvergent,
2056- ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
2057- NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
2058-
20592075 def int_nvvm_cp_async_bulk_tensor_prefetch_ # mode # _ # dim # d :
20602076 DefaultAttrsIntrinsicFlags<[],
20612077 !listconcat([llvm_ptr_ty], // tensormap_ptr
@@ -2065,18 +2081,6 @@ foreach dim = 1...5 in {
20652081 [llvm_i1_ty], // Flag for cache_hint
20662082 [IntrConvergent,
20672083 ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>;
2068-
2069- // Intrinsics for TMA Copy with reduction
2070- foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
2071- def int_nvvm_cp_async_bulk_tensor_reduce_ # red_op # _ # mode # _ # dim # d :
2072- DefaultAttrsIntrinsicFlags<[],
2073- !listconcat([llvm_shared_ptr_ty, // src_smem_ptr
2074- llvm_ptr_ty], // tensormap_ptr
2075- tensor_dim_args, // actual tensor dims
2076- [llvm_i64_ty]), // cache_hint
2077- [llvm_i1_ty], // Flag for cache_hint
2078- [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
2079- NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
20802084 }
20812085}
20822086
0 commit comments