@@ -333,3 +333,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
333333 tt.return
334334 }
335335}
336+
337+ // -----
338+
339+ // CHECK: #[[$DPAS0:.+]] = #ttig.dpas<{repeatCount = 1, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [1, 16], B = [16, 16], C = [1, 16]}>
340+ // CHECK: #[[$DPAS1:.+]] = #ttig.dpas<{repeatCount = 2, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [2, 16], B = [16, 16], C = [2, 16]}>
341+ // CHECK: #[[$DPAS2:.+]] = #ttig.dpas<{repeatCount = 4, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [4, 16], B = [16, 16], C = [4, 16]}>
342+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [1 , 1 ], order = [1 , 0 ], CTAsPerCGA = [1 , 1 ], CTASplitNum = [1 , 1 ], CTAOrder = [1 , 0 ]}>
343+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , " ttg.threads-per-warp" = 16 : i32 , " ttig.min_sg_size" = 16 : i32 , " ttig.support_dpas" } {
344+ tt.func @M_smaller_than_8 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }) {
345+ // CHECK-LABEL: M_smaller_than_8
346+ %b = arith.constant dense <0.000000e+00 > : tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>>
347+
348+ // CHECK: tt.dot {{.*}} -> tensor<1x16xf32, #[[$DPAS0]]>
349+ %a0 = arith.constant dense <0.000000e+00 > : tensor <1 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
350+ %zero0 = arith.constant dense <0.000000e+00 > : tensor <1 x16 xf32 , #blocked >
351+ %result0 = tt.dot %a0 , %b , %zero0 : tensor <1 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <1 x16 xf32 , #blocked >
352+ %result_ptr0 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <1 x16 x!tt.ptr <f32 >, #blocked >
353+ tt.store %result_ptr0 , %result0 : tensor <1 x16 x!tt.ptr <f32 >, #blocked >
354+
355+ // CHECK: tt.dot {{.*}} -> tensor<2x16xf32, #[[$DPAS1]]>
356+ %a1 = arith.constant dense <0.000000e+00 > : tensor <2 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
357+ %zero1 = arith.constant dense <0.000000e+00 > : tensor <2 x16 xf32 , #blocked >
358+ %result1 = tt.dot %a1 , %b , %zero1 : tensor <2 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <2 x16 xf32 , #blocked >
359+ %result_ptr1 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <2 x16 x!tt.ptr <f32 >, #blocked >
360+ tt.store %result_ptr1 , %result1 : tensor <2 x16 x!tt.ptr <f32 >, #blocked >
361+
362+ // CHECK: tt.dot {{.*}} -> tensor<4x16xf32, #[[$DPAS2]]>
363+ %a2 = arith.constant dense <0.000000e+00 > : tensor <4 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>>
364+ %zero2 = arith.constant dense <0.000000e+00 > : tensor <4 x16 xf32 , #blocked >
365+ %result2 = tt.dot %a2 , %b , %zero2 : tensor <4 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #blocked }>> * tensor <128 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #blocked }>> -> tensor <4 x16 xf32 , #blocked >
366+ %result_ptr2 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <4 x16 x!tt.ptr <f32 >, #blocked >
367+ tt.store %result_ptr2 , %result2 : tensor <4 x16 x!tt.ptr <f32 >, #blocked >
368+ tt.return
369+ }
370+ }
0 commit comments