@@ -30,13 +30,10 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
 // CHECK-NEXT: [[AREF2:%.*]] = nvws.aref.create [[AREF_BUF2]]
 %1 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %0) -> (!ttg.async.token) : i32 {
 %2 = arith.muli %arg5, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32
-// CHECK: [[C_ZERO1:%.*]] = arith.constant {ttg.partition = 2 : i32} 0
-// CHECK: [[PUT_BUF1:%.*]], [[TOKEN1:%.*]] = nvws.aref.put.enter [[AREF1]][[[C_ZERO1]], [[C_ZERO1]]] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
+// CHECK: [[PUT_BUF1:%.*]], [[TOKEN1:%.*]] = nvws.aref.put.enter [[AREF1]] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
 // CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF1]]
-// CHECK: [[C_ZERO2:%.*]] = arith.constant {ttg.partition = 2 : i32} 0
-// CHECK: nvws.aref.put.exit [[AREF1]][[[C_ZERO2]]], [[TOKEN1]] [#nvws.async_op<tma_load>] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
-// CHECK: [[C_ZERO3:%.*]] = arith.constant {ttg.partition = 1 : i32} 0
-// CHECK: [[GET_BUF1:%.*]], [[GET_TOKEN1:%.*]] = nvws.aref.get.enter [[AREF1]][[[C_ZERO3]], [[C_ZERO3]]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
+// CHECK: nvws.aref.put.exit [[AREF1]], [[TOKEN1]] [#nvws.async_op<tma_load>] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32}
+// CHECK: [[GET_BUF1:%.*]], [[GET_TOKEN1:%.*]] = nvws.aref.get.enter [[AREF1]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
 %3 = tt.descriptor_load %arg3[%arg1, %2] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
 // CHECK: [[PUT_BUF2:%.*]], [[TOKEN2:%.*]] = nvws.aref.put.enter [[AREF2]]
 // CHECK-NEXT: nvws.descriptor_load {{.*}} 16384 [[PUT_BUF2]]
@@ -51,9 +48,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
 %7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>
 // CHECK: ttng.tc_gen5_mma [[GET_BUF1]], [[RHS]], {{.*}}, {{.*}}, {{.*}}
 %8 = ttng.tc_gen5_mma %5, %7, %result[%arg6], %true, %true {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-// CHECK: nvws.aref.get.exit [[AREF2]][{{.*}}], [[GET_TOKEN2]]
-// CHECK: [[C_ZERO4:%.*]] = arith.constant {ttg.partition = 1 : i32} 0
-// CHECK: nvws.aref.get.exit [[AREF1]][[[C_ZERO4]]], [[GET_TOKEN1]] [#nvws.async_op<tc5mma>] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
+// CHECK: nvws.aref.get.exit [[AREF2]], [[GET_TOKEN2]]
+// CHECK: nvws.aref.get.exit [[AREF1]], [[GET_TOKEN1]] [#nvws.async_op<tc5mma>] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 1 : i32}
 scf.yield %8 : !ttg.async.token
 } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
 %result_0, %token_1 = ttng.tmem_load %result[%1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>