@@ -1088,3 +1088,73 @@ tt.func public @attention_forward(
 }
 
 }
+
+// -----
+
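+// Persistent attention kernel with an inner loop over K/V tiles. The CHECK
+// lines below verify that the inner-loop descriptor loads and tc_gen5_mma ops
+// are tagged with tt.latency / tt.self_latency attributes.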
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} {
+    %false = arith.constant false
+    %true = arith.constant true
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %cst_16 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %prog_id = tt.get_program_id x : i32
+    %num_sm = tt.get_num_programs x : i32
+    %num_tiles = arith.divsi %M, %c128_i32 : i32
+    %tiles_per_sm = arith.divsi %num_tiles, %num_sm : i32
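+    // Persistent outer loop: each program starts at its program id and
+    // strides over row tiles by the number of SMs.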
+    %tile_idx = scf.for %_ = %c0_i32 to %tiles_per_sm step %c1_i32 iter_args(%tile_idx_20 = %prog_id) -> (i32) : i32 {
+      %off_m = arith.muli %tile_idx_20, %c128_i32 : i32
+      %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+      %q_21 = ttg.local_alloc %q : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      %qk_22, %qk_23 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+      %acc, %acc_24 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
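+      // Inner loop over K/V tiles; iter_args carry the running row max and
+      // row sum plus the async tokens for the two tmem buffers.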
+      %acc_26:4 = scf.for %acc_30 = %c0_i32 to %N step %c128_i32 iter_args(%arg28 = %cst_16, %arg29 = %cst, %qk_31 = %qk_23, %acc_32 = %acc_24) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 {
+        // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32}
+        %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+        %k_33 = ttg.local_alloc %k : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+        %k_34 = ttg.memdesc_trans %k_33 {order = array<i32: 1, 0>} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+        // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.latency = 2 : i32, tt.self_latency = 1 : i32}
+        %qk_35 = ttng.tc_gen5_mma %q_21, %k_34, %qk_22[%qk_31], %false, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        %qk_36, %qk_37 = ttng.tmem_load %qk_22[%qk_35] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+
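+        // "softmax_work" is an opaque stand-in for the online-softmax update:
+        // it yields the accumulator correction factor, the P tile, and the
+        // updated row sum and row max.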
+        %acc_47, %p, %next_l_i, %row_max = "softmax_work"(%qk_36, %arg29, %arg28) : (tensor<128x128xf32, #blocked>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf16, #blocked>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+
+        %acc_48, %acc_49 = ttng.tmem_load %acc[%acc_32] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+        %acc_50 = arith.mulf %acc_48, %acc_47 : tensor<128x128xf32, #blocked>
+        %p_53 = ttg.local_alloc %p : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+        %acc_54 = ttng.tmem_store %acc_50, %acc[%acc_49], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32}
+        %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+        %v_51 = ttg.local_alloc %v : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+
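+        // Second MMA (P·V) accumulates into the persistent tmem accumulator.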
+        // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.self_latency = 1 : i32}
+        %acc_55 = ttng.tc_gen5_mma %p_53, %v_51, %acc[%acc_54], %true, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+
+        scf.yield %row_max, %next_l_i, %qk_37, %acc_55 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token
+      }
+      %tile_idx_29 = arith.addi %tile_idx_20, %num_sm : i32
+      scf.yield %tile_idx_29 : i32
+    } {tt.num_stages = 3 : i32, tt.warp_specialize}
+    tt.return
+  }
+}