@@ -137,12 +137,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
137
137
#smem = #ttg.shared_memory
138
138
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " hip:gfx950" , " ttg.threads-per-warp" = 64 : i32 } {
139
139
140
- // CHECK-LABEL: reject_chained_dots_empty_mem_cluster
140
+ // CHECK-LABEL: reject_chained_dots_empty_mem_cluster_1
141
141
142
142
// CHECK-NOT: setprio
143
143
// CHECK-NOT: barrier
144
144
145
- tt.func @reject_chained_dots_empty_mem_cluster (%arg0: tensor <64 x16 xf16 , #blocked >, %arg1: tensor <64 x16 x!tt.ptr <f16 >, #blocked >, %arg2: i32 , %arg3: i32 , %arg4: tensor <128 x16 xf32 , #mma >, %arg5: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>, %arg6: i32 , %arg7: tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, %arg8: tensor <128 x16 xf32 , #mma >, %arg9: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg10: tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>, %arg11: i32 , %arg12: i32 , %arg13: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>) -> tensor <128 x16 xf32 , #mma > {
145
+ tt.func @reject_chained_dots_empty_mem_cluster_1 (%arg0: tensor <64 x16 xf16 , #blocked >, %arg1: tensor <64 x16 x!tt.ptr <f16 >, #blocked >, %arg2: i32 , %arg3: i32 , %arg4: tensor <128 x16 xf32 , #mma >, %arg5: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>, %arg6: i32 , %arg7: tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, %arg8: tensor <128 x16 xf32 , #mma >, %arg9: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg10: tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>, %arg11: i32 , %arg12: i32 , %arg13: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>) -> tensor <128 x16 xf32 , #mma > {
146
146
%c1_i32 = arith.constant 1 : i32
147
147
%c0_i32 = arith.constant 0 : i32
148
148
%0 = ttg.local_alloc : () -> !ttg.memdesc <2 x64 x16 xf16 , #shared , #smem , mutable >
@@ -164,3 +164,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
164
164
tt.return %5#0 : tensor <128 x16 xf32 , #mma >
165
165
}
166
166
}
167
+
168
+ // -----
169
+
170
+ #blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
171
+ #mma = #ttg.amd_mfma <{version = 4 , warpsPerCTA = [4 , 1 ], instrShape = [32 , 32 ], isTransposed = true }>
172
+ #shared = #ttg.swizzled_shared <{vec = 2 , perPhase = 2 , maxPhase = 8 , order = [0 , 1 ]}>
173
+ #smem = #ttg.shared_memory
174
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " hip:gfx950" , " ttg.threads-per-warp" = 64 : i32 } {
175
+
176
+ // CHECK-LABEL: reject_chained_dots_empty_mem_cluster_2
177
+
178
+ // CHECK-NOT: setprio
179
+ // CHECK-NOT: barrier
180
+
181
+ tt.func @reject_chained_dots_empty_mem_cluster_2 (%memdesc1: !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, %memdesc2: !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, %alloc1: !ttg.memdesc <2 x64 x16 xf16 , #shared , #smem , mutable >, %alloc2: !ttg.memdesc <2 x64 x16 xf16 , #shared , #smem , mutable >, %arg0: tensor <64 x16 xf16 , #blocked >, %arg1: tensor <64 x16 x!tt.ptr <f16 >, #blocked >, %arg2: i32 , %arg3: i32 , %arg4: tensor <128 x16 xf32 , #mma >, %arg5: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>, %arg6: i32 , %arg7: tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, %arg8: tensor <128 x16 xf32 , #mma >, %arg9: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg10: tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>, %arg11: i32 , %arg12: i32 , %arg13: tensor <128 xf32 , #ttg.slice <{dim = 1 , parent = #mma }>>) -> tensor <128 x16 xf32 , #mma > {
182
+ %5:8 = scf.for %arg14 = %arg3 to %arg2 step %arg3 iter_args (%arg15 = %arg4 , %arg16 = %arg4 , %arg17 = %arg7 , %arg18 = %memdesc1 , %arg19 = %memdesc1 , %arg20 = %memdesc2 , %arg21 = %arg0 , %arg22 = %arg0 ) -> (tensor <128 x16 xf32 , #mma >, tensor <128 x16 xf32 , #mma >, tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, tensor <64 x16 xf16 , #blocked >, tensor <64 x16 xf16 , #blocked >) : i32 {
183
+ %6 = tt.dot %arg10 , %arg17 , %arg15 : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x16 xf32 , #mma >
184
+ ttg.local_store %arg22 , %arg20 : tensor <64 x16 xf16 , #blocked > -> !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >
185
+ %11 = ttg.local_load %arg20 : !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 > -> tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
186
+ %13 = tt.load %arg1 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >
187
+ %10 = tt.dot %arg10 , %arg17 , %arg16 : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x16 xf32 , #mma >
188
+ scf.yield %10 , %6 , %11 , %arg19 , %arg20 , %arg20 , %13 , %13 : tensor <128 x16 xf32 , #mma >, tensor <128 x16 xf32 , #mma >, tensor <64 x16 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, !ttg.memdesc <64 x16 xf16 , #shared , #smem , mutable , 2 x64 x16 >, tensor <64 x16 xf16 , #blocked >, tensor <64 x16 xf16 , #blocked >
189
+ }
190
+ tt.return %5#0 : tensor <128 x16 xf32 , #mma >
191
+ }
192
+ }
0 commit comments