@@ -95,27 +95,44 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 #blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked8 = #ttg.blocked<{sizePerThread = [1, 1, 1, 2, 4], threadsPerWarp = [1, 1, 16, 2, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 3, 2, 1, 0]}>
 #blocked9 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 4], threadsPerWarp = [1, 2, 16, 1, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 1, 2, 3, 0]}>
 #blocked10 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
 #blocked11 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 #tmem_scales = #ttng.tensor_memory_scales_encoding<>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-LABEL: @inject_tmem_copy
-  // CHECK: ttng.tmem_alloc {{.*}}, mutable
-  // CHECK: ttng.tmem_copy
+  // CHECK-LABEL: @scales_in_shmem
+  // CHECK: %[[A_LA:.*]] = ttg.local_alloc
+  // CHECK: %[[B_LA:.*]] = ttg.local_alloc
+  // CHECK: ttng.tc_gen5_mma_scaled {{.*}}, %[[A_LA]], %[[B_LA]],
 
-  tt.func public @inject_tmem_copy(%scale: tensor<2x512x!tt.ptr<i8>, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %75 = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
-    %180 = ttg.local_load %75 : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
-    %183 = tt.reshape %180 : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
-    %184 = tt.trans %183 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
-    %187 = ttg.convert_layout %184 : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
-    %188 = tt.reshape %187 : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
-    %190 = ttng.tmem_alloc %188 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+  tt.func public @scales_in_shmem(
+    %scale: tensor<2x512x!tt.ptr<i8>, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
+    %A_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+    %B_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+    %acc_tm: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
+  ) attributes {noinline = false} {
+    %true = arith.constant true
+    %A_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
+    %B_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
+    %A_ll = ttg.local_load %A_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
+    %B_ll = ttg.local_load %B_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
+    %A_r = tt.reshape %A_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
+    %B_r = tt.reshape %B_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
+    %A_tr = tt.trans %A_r {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
+    %B_tr = tt.trans %B_r {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
+    %A_cv = ttg.convert_layout %A_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
+    %B_cv = ttg.convert_layout %B_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
+    %A_r2 = tt.reshape %A_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
+    %B_r2 = tt.reshape %B_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
+    %A_tm = ttng.tmem_alloc %A_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+    %B_tm = ttng.tmem_alloc %B_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+    ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_tm, %B_tm, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> ()
     tt.return
   }
 }