@@ -158,3 +158,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
158
158
tt.return %res : tensor <8 x16 xf16 >
159
159
}
160
160
}
161
+
162
+
163
+ // -----
164
+
165
+ #dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [2 , 2 ], repCluster = [1 , 1 ]}>
166
+ #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
167
+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
168
+ #dot_operand_a = #ttg.dot_op <{opIdx =0 , parent =#dpas , kWidth =1 }>
169
+ #dot_operand_b = #ttg.dot_op <{opIdx =1 , parent =#dpas , kWidth =2 }>
170
+ #smem = #ttg.shared_memory
171
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
172
+ // CHECK-LABEL: matmul_tf32dot
173
+ tt.func @matmul_tf32dot (%ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 },
174
+ %a: !ttg.memdesc <32 x16 xf32 , #shared , #smem >, %b: !ttg.memdesc <16 x32 xf32 , #shared , #smem >) {
175
+ %cst = arith.constant dense <0.000000e+00 > : tensor <32 x32 xf32 , #dpas >
176
+ %a_mat = ttg.local_load %a : !ttg.memdesc <32 x16 xf32 , #shared , #smem > -> tensor <32 x16 xf32 , #dot_operand_a >
177
+ %b_mat = ttg.local_load %b : !ttg.memdesc <16 x32 xf32 , #shared , #smem > -> tensor <16 x32 xf32 , #dot_operand_b >
178
+
179
+ // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
180
+ %28 = tt.dot %a_mat , %b_mat , %cst , inputPrecision = tf32 : tensor <32 x16 xf32 , #dot_operand_a > * tensor <16 x32 xf32 , #dot_operand_b > -> tensor <32 x32 xf32 , #dpas >
181
+ %38 = ttg.convert_layout %28 : tensor <32 x32 xf32 , #dpas > -> tensor <32 x32 xf32 , #blocked >
182
+
183
+ tt.return
184
+ }
185
+ }
186
+
187
+ // -----
188
+
189
+ #dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 1 , threadsPerWarp = 16 , warpsPerCTA = [2 , 2 ], repCluster = [1 , 1 ]}>
190
+ #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
191
+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
192
+ #dot_operand_a = #ttg.dot_op <{opIdx =0 , parent =#dpas , kWidth =1 }>
193
+ // expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
194
+ #dot_operand_b = #ttg.dot_op <{opIdx =1 , parent =#dpas , kWidth =2 }>
195
+ #smem = #ttg.shared_memory
196
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
197
+ // CHECK-LABEL: matmul_tf32dot
198
+ tt.func @matmul_tf32dot (%ptr: !tt.ptr <f32 > {tt.divisibility = 16 : i32 },
199
+ %a: !ttg.memdesc <32 x16 xf32 , #shared , #smem >, %b: !ttg.memdesc <16 x32 xf32 , #shared , #smem >) {
200
+ %cst = arith.constant dense <0.000000e+00 > : tensor <32 x32 xf32 , #dpas >
201
+ %a_mat = ttg.local_load %a : !ttg.memdesc <32 x16 xf32 , #shared , #smem > -> tensor <32 x16 xf32 , #dot_operand_a >
202
+ %b_mat = ttg.local_load %b : !ttg.memdesc <16 x32 xf32 , #shared , #smem > -> tensor <16 x32 xf32 , #dot_operand_b >
203
+
204
+ %28 = tt.dot %a_mat , %b_mat , %cst , inputPrecision = tf32 : tensor <32 x16 xf32 , #dot_operand_a > * tensor <16 x32 xf32 , #dot_operand_b > -> tensor <32 x32 xf32 , #dpas >
205
+ %38 = ttg.convert_layout %28 : tensor <32 x32 xf32 , #dpas > -> tensor <32 x32 xf32 , #blocked >
206
+
207
+ tt.return
208
+ }
209
+ }