@@ -137,12 +137,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
 
-  // CHECK-LABEL: reject_chained_dots_empty_mem_cluster
+  // CHECK-LABEL: reject_chained_dots_empty_mem_cluster_1
 
   // CHECK-NOT: setprio
   // CHECK-NOT: barrier
 
-  tt.func @reject_chained_dots_empty_mem_cluster(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
+  tt.func @reject_chained_dots_empty_mem_cluster_1(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
     %c1_i32 = arith.constant 1 : i32
     %c0_i32 = arith.constant 0 : i32
     %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>
@@ -164,3 +164,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %5#0 : tensor<128x16xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 8, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+
+  // CHECK-LABEL: reject_chained_dots_empty_mem_cluster_2
+
+  // CHECK-NOT: setprio
+  // CHECK-NOT: barrier
+
+  tt.func @reject_chained_dots_empty_mem_cluster_2(%memdesc1: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %memdesc2: !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, %alloc1: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %alloc2: !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable>, %arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
+    %5:8 = scf.for %arg14 = %arg3 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %memdesc1, %arg19 = %memdesc1, %arg20 = %memdesc2, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 {
+      %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
+      ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>
+      %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %13 = tt.load %arg1 : tensor<64x16x!tt.ptr<f16>, #blocked>
+      %10 = tt.dot %arg10, %arg17, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma>
+      scf.yield %10, %6, %11, %arg19, %arg20, %arg20, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>
+    }
+    tt.return %5#0 : tensor<128x16xf32, #mma>
+  }
+}
0 commit comments