@@ -106,24 +106,28 @@ module {
106106// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
107107// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
108108// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
109- // CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x64xbf16>>
110- // CHECK: [[VAR_31_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index
111- // CHECK: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : index to i64
112- // CHECK: [[VAR_38_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<64x256xbf16>>
113- // CHECK-DAG: [[VAR_39_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
114- // CHECK-DAG: [[VAR_40_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
115- // CHECK: [[VAR_41_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_28_]], [[VAR_arg15_:%.+]] = [[VAR_38_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
116- // CHECK-DAG: [[VAR_54_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
117- // CHECK-DAG: [[VAR_55_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
118- // CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
119- // CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32>
120- // CHECK-DAG: [[VAR_58_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_39_]]] : <tensor<128x64xbf16>>
121- // CHECK-DAG: [[VAR_59_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_40_]]] : <tensor<64x256xbf16>>
122- // CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
109+ // CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
110+ // CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
111+ // CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32
112+ // CHECK: [[VAR_23_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_20_]], [[VAR_21_]]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] {{.*}} : <tensor<128x64xbf16>>
113+ // CHECK: [[VAR_24_:%.+]] = arith.extsi [[PARAM_8_]] : i32 to i64
114+ // CHECK: [[VAR_25_:%.+]] = arith.muli {{.*}}, [[PARAM_9_]] : i32
115+ // CHECK: [[VAR_26_:%.+]] = arith.extsi [[PARAM_9_]] : i32 to i64
116+ // CHECK: [[VAR_27_:%.+]] = arith.divui [[VAR_25_]], [[PARAM_9_]] : i32
117+ // CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_24_]], [[VAR_26_]]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] {{.*}} : <tensor<64x256xbf16>>
118+ // CHECK-DAG: [[VAR_29_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32
119+ // CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32
120+ // CHECK: [[VAR_31_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_23_]], [[VAR_arg15_:%.+]] = [[VAR_28_]]) -> (tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>) : i32 {
121+ // CHECK-DAG: [[VAR_40_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr<tensor<128x64xbf16>>
122+ // CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
123+ // CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
124+ // CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32>
125+ // CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : <tensor<128x64xbf16>>
126+ // CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<64x256xbf16>>
127+ // CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
123128// CHECK: }
124- // CHECK-DAG: [[VAR_42_:%.+]] = arith.truncf [[VAR_41_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
125- // CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index
126- // CHECK: [[VAR_53_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
127- // CHECK: tt.store [[VAR_53_]], [[VAR_42_]] : !tt.ptr<tensor<128x256xbf16>>
129+ // CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
130+ // CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
131+ // CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr<tensor<128x256xbf16>>
128132// CHECK: tt.return
129133// CHECK: }
0 commit comments