@@ -2925,7 +2925,7 @@ tt.func @hoist_multiple_conditional(
29252925 }
29262926 %2 = arith.addf %0 , %1 : tensor <128 x32 xf32 , #blocked >
29272927 %3 = ttg.convert_layout %2 : tensor <128 x32 xf32 , #blocked > -> tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
2928- %4 = tt.dot %3 , %arg4 , %arg5 : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
2928+ %4 = tt.dot %3 , %arg4 , %arg5 , inputPrecision = tf32 : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
29292929 tt.return %4 : tensor <128 x128 xf32 , #mma >
29302930}
29312931
@@ -2954,7 +2954,7 @@ tt.func @hoist_across_loop(
29542954 }
29552955 // CHECK-NOT: ttg.convert_layout
29562956 %2 = ttg.convert_layout %1 : tensor <128 x32 xf32 , #blocked > -> tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
2957- %3 = tt.dot %2 , %arg2 , %acc : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
2957+ %3 = tt.dot %2 , %arg2 , %acc , inputPrecision = tf32 : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
29582958 scf.yield %1 , %3 : tensor <128 x32 xf32 , #blocked >, tensor <128 x128 xf32 , #mma >
29592959 }
29602960 tt.return %0#1 : tensor <128 x128 xf32 , #mma >
0 commit comments