@@ -333,3 +333,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
333
333
tt.return
334
334
}
335
335
}
336
+
337
+ // -----
338
+
339
+ // CHECK: #[[$DPAS0:.+]] = #ttig.dpas<{repeatCount = 1, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [1, 16], B = [16, 16], C = [1, 16]}>
340
+ // CHECK: #[[$DPAS1:.+]] = #ttig.dpas<{repeatCount = 2, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [2, 16], B = [16, 16], C = [2, 16]}>
341
+ // CHECK: #[[$DPAS2:.+]] = #ttig.dpas<{repeatCount = 4, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [4, 16], B = [16, 16], C = [4, 16]}>
342
+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [1 , 1 ], order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
343
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.min_sg_size" = 16 : i32 , " ttig.support_dpas" } {
344
+ tt.func @M_smaller_than_8 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }) {
345
+ // CHECK-LABEL: M_smaller_than_8
346
+ %b = arith.constant dense <0.000000e+00 > : tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>>
347
+
348
+ // CHECK: tt.dot {{.*}} -> tensor<1x16xf32, #[[$DPAS0]]>
349
+ %a0 = arith.constant dense <0.000000e+00 > : tensor <1 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
350
+ %zero0 = arith.constant dense <0.000000e+00 > : tensor <1 x16 xf32 , #blocked >
351
+ %result0 = tt.dot %a0 , %b , %zero0 : tensor <1 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <1 x16 xf32 , #blocked >
352
+ %result_ptr0 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <1 x16 x!tt.ptr <f32 >, #blocked >
353
+ tt.store %result_ptr0 , %result0 : tensor <1 x16 x!tt.ptr <f32 >, #blocked >
354
+
355
+ // CHECK: tt.dot {{.*}} -> tensor<2x16xf32, #[[$DPAS1]]>
356
+ %a1 = arith.constant dense <0.000000e+00 > : tensor <2 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
357
+ %zero1 = arith.constant dense <0.000000e+00 > : tensor <2 x16 xf32 , #blocked >
358
+ %result1 = tt.dot %a1 , %b , %zero1 : tensor <2 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <2 x16 xf32 , #blocked >
359
+ %result_ptr1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <2 x16 x!tt.ptr <f32 >, #blocked >
360
+ tt.store %result_ptr1 , %result1 : tensor <2 x16 x!tt.ptr <f32 >, #blocked >
361
+
362
+ // CHECK: tt.dot {{.*}} -> tensor<4x16xf32, #[[$DPAS2]]>
363
+ %a2 = arith.constant dense <0.000000e+00 > : tensor <4 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
364
+ %zero2 = arith.constant dense <0.000000e+00 > : tensor <4 x16 xf32 , #blocked >
365
+ %result2 = tt.dot %a2 , %b , %zero2 : tensor <4 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <4 x16 xf32 , #blocked >
366
+ %result_ptr2 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <4 x16 x!tt.ptr <f32 >, #blocked >
367
+ tt.store %result_ptr2 , %result2 : tensor <4 x16 x!tt.ptr <f32 >, #blocked >
368
+ tt.return
369
+ }
370
+ }
0 commit comments