@@ -788,3 +788,57 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
788788 tt.return
789789 }
790790}
791+
792+ // -----
793+
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
  // Warp-specialized (tt.warp_specialize) pipelined loop where an scf.if
  // consumes the MMA accumulator token. The CHECK lines expect the pass to
  // split the scf.if so the accumulator put.exit/put.enter land in the MMA's
  // cluster/stage (2/2) while the user body keeps its own cluster/stage (4/3).
  // CHECK-LABEL: @if_split_workaround
  tt.func @if_split_workaround(%arg0: !tt.tensordesc<tensor<1x64xf16, #shared>>, %arg1: tensor<64x128x!tt.ptr<f16>, #blocked3> {tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>}) {
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %true = arith.constant true
    %false = arith.constant false
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %c32_i32 = arith.constant 32 : i32
    %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
    %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK: scf.for
    %1:3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %arg1, %arg5 = %0) -> (i1, tensor<64x128x!tt.ptr<f16>, #blocked3>, !ttg.async.token) : i32 {
      %2:3 = "get_offsets"(%arg2) {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 1, 2>} : (i32) -> (i32, tensor<64x128xi32, #blocked3>, i32)
      %3 = tt.splat %2#0 {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : i32 -> tensor<128xi32, #blocked2>
      %4 = tt.descriptor_gather %arg0[%3, %2#2] {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : (!tt.tensordesc<tensor<1x64xf16, #shared>>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1>
      %5 = tt.addptr %arg4, %2#1 {loop.cluster = 3 : i32, loop.stage = 1 : i32, tt.constancy = dense<1> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>, ttg.partition = array<i32: 1>} : tensor<64x128x!tt.ptr<f16>, #blocked3>, tensor<64x128xi32, #blocked3>
      %6 = tt.load %5 {loop.cluster = 3 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 1>} : tensor<64x128x!tt.ptr<f16>, #blocked3>
      %7 = ttg.local_alloc %4 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array<i32: 2>} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
      %8 = ttg.local_alloc %6 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array<i32: 1>} : (tensor<64x128xf16, #blocked3>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
      // CHECK: tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
      %9 = ttng.tc_gen5_mma %7, %8, %result[%arg5], %arg3, %true {loop.cluster = 2 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      %10 = arith.cmpi eq, %arg2, %c0_i32 {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 0, 1>} : i32
      %11 = arith.select %10, %false, %true {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 1>} : i1
      // CHECK: scf.if
      // CHECK-NEXT: put.exit {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
      // CHECK: } {loop.cluster = 2 : i32, loop.stage = 2 : i32
      // CHECK: scf.if
      // CHECK: } {loop.cluster = 4 : i32, loop.stage = 3 : i32
      // CHECK: scf.if
      // CHECK-NEXT: put.enter {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
      // CHECK: } {loop.cluster = 2 : i32, loop.stage = 2 : i32
      %12 = scf.if %10 -> (!ttg.async.token) {
        %result_0, %token_1 = ttng.tmem_load %result[%9] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
        "acc_user"(%result_0) {ttg.partition = array<i32: 0>} : (tensor<128x128xf32, #blocked>) -> ()
        scf.yield {ttg.partition = array<i32: 0, 1>} %token_1 : !ttg.async.token
      } else {
        scf.yield {ttg.partition = array<i32: 0, 1>} %9 : !ttg.async.token
      } {loop.cluster = 4 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 0, 1>, ttg.partition.outputs = [array<i32: 1>]}
      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %11, %5, %12 : i1, tensor<64x128x!tt.ptr<f16>, #blocked3>, !ttg.async.token
    } {tt.disallow_acc_multi_buffer, tt.num_stages = 3 : i32, tt.scheduled_max_stage = 3 : i32, tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 1>, array<i32: 1>, array<i32: 1>], ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 2 : i32}
    tt.return
  }
}
0 commit comments