@@ -163,22 +163,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
// -----

#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK-LABEL: matmul_tf32dot
-  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
-                          %a: !ttg.memdesc<32x16xf32, #shared, #smem>, %b: !ttg.memdesc<16x32xf32, #shared, #smem>) {
+  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32>,
+                          %a_mat: tensor<32x16xf32, #dot_operand_a>, %b_mat: tensor<16x32xf32, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
-    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
-    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>

    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
-    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>

    tt.return
  }
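
For contrast, the diagnostic above says opsPerChannel = 2 expects a 16-bit element type, so a variant of this test with f16 operands would pass that element-type check. A minimal hypothetical sketch, not part of this commit:

// Hypothetical sketch: opsPerChan = 2 paired with f16 operands, so the
// "Expected 16 bit type" verifier check above would not fire.
#dpas_f16 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_b_f16 = #ttg.dot_op<{opIdx=1, parent=#dpas_f16, kWidth=2}>
// e.g. %b_mat: tensor<16x32xf16, #dot_b_f16>
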
@@ -187,22 +182,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// -----

#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK-LABEL: matmul_tf32dot
-  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
-                          %a: !ttg.memdesc<32x16xf32, #shared, #smem>, %b: !ttg.memdesc<16x32xf32, #shared, #smem>) {
+  tt.func @matmul_tf32dot(%ptr: !tt.ptr<f32>,
+                          %a_mat: tensor<32x16xf32, #dot_operand_a>, %b_mat: tensor<16x32xf32, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
-    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
-    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
-
    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
-    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>

    tt.return
  }
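
Here the diagnostic fires on the attribute itself: the parent layout has opsPerChan = 1 while operand B declares kWidth = 2. A minimal hypothetical sketch of the pairing the verifier accepts, not part of this commit:

// Hypothetical sketch: kWidth equals the parent's opsPerChan (both 1),
// so the "kWidth parameter must match" verifier error would not trigger.
#dpas1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_b1 = #ttg.dot_op<{opIdx=1, parent=#dpas1, kWidth=1}>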