@@ -95,7 +95,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
95
95
}
96
96
97
97
// -----
98
-
99
98
// 4 warps
100
99
// matmul: 128x32 @ 32x128 -> 128x128
101
100
#AL = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [4 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
@@ -107,6 +106,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
107
106
#A = #ttg.dot_op <{opIdx = 0 , parent = #C , kWidth =2 }>
108
107
#B = #ttg.dot_op <{opIdx = 1 , parent = #C , kWidth =2 }>
109
108
#shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 16 }>
109
+ #shared1 = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 16 }>
110
110
#smem = #ttg.shared_memory
111
111
#tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = true >
112
112
module attributes {" ttg.num-warps" = 4 : i32 , " ttg.num-ctas" = 1 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
@@ -119,6 +119,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
119
119
// CHECK: tt.fp_to_fp
120
120
// CHECK: ttng.wait_barrier
121
121
// CHECK: ttg.local_store
122
+ // CHECK: ttg.memdesc_trans
122
123
// CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}
123
124
// CHECK: ttg.async_copy_global_to_local
124
125
%a_ptr_splat = tt.splat %A : !tt.ptr <f8E4M3FN > -> tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >
@@ -127,37 +128,38 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
127
128
%a_offs = tt.broadcast %a_tmp1 : tensor <1 x32 xi32 , #AL > -> tensor <128 x32 xi32 , #AL >
128
129
%a_ptr_init = tt.addptr %a_ptr_splat , %a_offs : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <128 x32 xi32 , #AL >
129
130
130
- %b_ptr_splat = tt.splat %B : !tt.ptr <f8E4M3FN > -> tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >
131
- %b_tmp0 = tt.make_range {end = 128 : i32 , start = 0 : i32 } : tensor <128 x i32 , #BLs0 >
132
- %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32 } : tensor <128 x i32 , #BLs0 > -> tensor <1 x 128 x i32 , #BL >
133
- %b_offs = tt.broadcast %b_tmp1 : tensor <1 x 128 x i32 , #BL > -> tensor <32 x 128 x i32 , #BL >
134
- %b_ptr_init = tt.addptr %b_ptr_splat , %b_offs : tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >, tensor <32 x 128 x i32 , #BL >
131
+ %b_ptr_splat = tt.splat %B : !tt.ptr <f8E4M3FN > -> tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >
132
+ %b_tmp0 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 x i32 , #BLs0 >
133
+ %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32 } : tensor <32 x i32 , #BLs0 > -> tensor <1 x 32 x i32 , #BL >
134
+ %b_offs = tt.broadcast %b_tmp1 : tensor <1 x 32 x i32 , #BL > -> tensor <128 x 32 x i32 , #BL >
135
+ %b_ptr_init = tt.addptr %b_ptr_splat , %b_offs : tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x 32 x i32 , #BL >
135
136
136
137
%true = arith.constant true
137
- %b_mask = arith.constant dense <true > : tensor <32 x 128 x i1 , #BL >
138
- %b_other = arith.constant dense <0.00e+00 > : tensor <32 x 128 xf 8 E 4 M 3 FN , #BL >
138
+ %b_mask = arith.constant dense <true > : tensor <128 x 32 x i1 , #BL >
139
+ %b_other = arith.constant dense <0.00e+00 > : tensor <128 x 32 xf 8 E 4 M 3 FN , #BL >
139
140
%c_init = arith.constant dense <0.00e+00 > : tensor <128 x128 xf32 , #C >
140
141
141
142
%a_off = arith.constant dense <4 > : tensor <128 x32 xi32 , #AL >
142
- %b_off = arith.constant dense <4 > : tensor <32 x 128 x i32 , #BL >
143
+ %b_off = arith.constant dense <4 > : tensor <128 x 32 x i32 , #BL >
143
144
144
- %loop:3 = scf.for %iv = %lb to %ub step %step iter_args (%a_ptr = %a_ptr_init , %b_ptr = %b_ptr_init , %prev_c = %c_init ) -> (tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >) {
145
+ %loop:3 = scf.for %iv = %lb to %ub step %step iter_args (%a_ptr = %a_ptr_init , %b_ptr = %b_ptr_init , %prev_c = %c_init ) -> (tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >) {
145
146
%a___ = tt.load %a_ptr : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >
146
147
%a__ = tt.fp_to_fp %a___ : tensor <128 x32 xf8 E4 M3 FN, #AL > -> tensor <128 x32 xf16 , #AL >
147
148
%a_ = ttg.convert_layout %a__ : tensor <128 x32 xf16 , #AL > -> tensor <128 x32 xf16 , #A >
148
- %b___ = tt.load %b_ptr , %b_mask , %b_other : tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >
149
- %b__ = tt.fp_to_fp %b___ : tensor <32 x 128 xf 8 E 4 M 3 FN , #BL > -> tensor <32 x 128 x f16 , #BL >
150
- %b_ = ttg.convert_layout %b__ : tensor <32 x 128 x f16 , #BL > -> tensor <32 x 128 x f16 , #B >
149
+ %b___ = tt.load %b_ptr , %b_mask , %b_other : tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >
150
+ %b__ = tt.fp_to_fp %b___ : tensor <128 x 32 xf 8 E 4 M 3 FN , #BL > -> tensor <128 x 32 x f16 , #BL >
151
+ %b_ = ttg.convert_layout %b__ : tensor <128 x 32 x f16 , #BL > -> tensor <128 x 32 x f16 , #B >
151
152
152
153
%a = ttg.local_alloc %a_ {loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf16 , #A >) -> !ttg.memdesc <128 x32 xf16 , #shared , #smem >
153
- %b = ttg.local_alloc %b_ {loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <32 x128 xf16 , #B >) -> !ttg.memdesc <32 x128 xf16 , #shared , #smem >
154
+ %b = ttg.local_alloc %b_ {loop.cluster = 0 : i32 , loop.stage = 2 : i32 } : (tensor <128 x32 xf16 , #B >) -> !ttg.memdesc <128 x32 xf16 , #shared , #smem >
155
+ %bt = ttg.memdesc_trans %b {loop.cluster = 1 : i32 , loop.stage = 2 : i32 , order = array<i32 : 1 , 0 >} : !ttg.memdesc <128 x32 xf16 , #shared , #smem > -> !ttg.memdesc <32 x128 xf16 , #shared1 , #smem >
154
156
%acc_tm , %acc_tok = ttng.tmem_alloc %prev_c : (tensor <128 x128 xf32 , #C >) -> (!ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >, !ttg.async.token )
155
- %mma_tok = ttng.tc_gen5_mma %a , %b , %acc_tm [%acc_tok ], %true , %true : !ttg.memdesc <128 x32 xf16 , #shared , #smem >, !ttg.memdesc <32 x128 xf16 , #shared , #smem >, !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
157
+ %mma_tok = ttng.tc_gen5_mma %a , %bt , %acc_tm [%acc_tok ], %true , %true : !ttg.memdesc <128 x32 xf16 , #shared , #smem >, !ttg.memdesc <32 x128 xf16 , #shared1 , #smem >, !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable >
156
158
%c , %load_tok = ttng.tmem_load %acc_tm [%mma_tok ] : !ttg.memdesc <128 x128 xf32 , #tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 , #C >
157
159
158
160
%next_a_ptr = tt.addptr %a_ptr , %a_off : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <128 x32 xi32 , #AL >
159
- %next_b_ptr = tt.addptr %b_ptr , %b_off : tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >, tensor <32 x 128 x i32 , #BL >
160
- scf.yield %next_a_ptr , %next_b_ptr , %c : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <32 x 128 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >
161
+ %next_b_ptr = tt.addptr %b_ptr , %b_off : tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x 32 x i32 , #BL >
162
+ scf.yield %next_a_ptr , %next_b_ptr , %c : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <128 x 32 x !tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >
161
163
}
162
164
tt.return %loop#2: tensor <128 x128 xf32 , #C >
163
165
}
0 commit comments