@@ -63,3 +63,76 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar
     tt.return
   }
 }
+
+// -----
+
+// COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding
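+// COM: The transposed (column major) A load must keep its #blocked1 encoding;
+// COM: only the row major B load is expected to be rewritten to the
+// COM: #ttig.subgroup_2d_block encoding (bound to #mma in the checks below).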
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+// CHECK-NOT: #mma2
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    %2 = arith.muli %1, %c4_i32 : i32
+    %3 = arith.subi %c4_i32, %2 : i32
+    %4 = arith.minsi %3, %c4_i32 : i32
+    %5 = arith.remsi %0, %4 : i32
+    %6 = arith.addi %2, %5 : i32
+    %7 = arith.remsi %0, %c64_i32 : i32
+    %8 = arith.divsi %7, %4 : i32
+    %9 = arith.muli %6, %c256_i32 : i32
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c1_i64, %c1024_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
+    %11 = arith.muli %8, %c256_i32 : i32
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
+      // CHECK: {{.*}} = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
+      %17 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
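+      // COM: The B load result type becomes #mma; a convert_layout back to
+      // COM: #blocked2 is expected so that existing uses remain valid.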
+      // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #mma>>
+      // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked2>
+      %18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
+      %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+      %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+      %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+      %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
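+      // COM: In the rewritten module the dpas layout is renumbered to #mma1,
+      // COM: so the dot below is matched with parent = #mma1.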
+      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<256x256xf32, #mma1>
+      %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
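+      // COM: The advance on the A pointer keeps #blocked1; the B pointer is
+      // COM: expected to carry the new #mma encoding.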
+      // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #blocked1>>
+      %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
+      // CHECK: tt.advance {{.*}} : <tensor<32x256xf16, #mma>>
+      %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
+      scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
+    tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
+    tt.return
+  }
+}