@@ -41,6 +41,8 @@ module {
4141// CHECK-LABEL: func.func @kernel
4242// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
4343// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : i32
44+ // CHECK-DAG: [[CST_512_:%.+]] = arith.constant 512 : i32
45+ // CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : i32
4446// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256x128xi32>
4547// CHECK-NOT: separator of consecutive DAGs
4648// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_6_]] : i32) outs([[VAR_0_]] : tensor<256x128xi32>) -> tensor<256x128xi32>
@@ -49,7 +51,8 @@ module {
4951// CHECK: ^bb0([[IN_0_:%.+]]: i32):
5052// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index
5153// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32
52- // CHECK: linalg.yield [[VAR_14_]] : i32
54+ // CHECK: [[VAL_24:%.+]] = arith.addi [[VAR_14_]], [[CST_512_]] : i32
55+ // CHECK: linalg.yield [[VAL_24]] : i32
5356// CHECK: } -> tensor<256xi32>
5457// CHECK: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_3_]] {{.}}[0, 1]{{.}} output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32>
5558// CHECK: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<256x1xi32>) outs([[VAR_0_]] : tensor<256x128xi32>) attrs = {broadcastDims = array<i64: 1>} {
@@ -61,7 +64,8 @@ module {
6164// CHECK: ^bb0([[IN_3_:%.+]]: i32):
6265// CHECK: [[VAR_13_1_:%.+]] = linalg.index 0 : index
6366// CHECK: [[VAR_14_1_:%.+]] = arith.index_cast [[VAR_13_1_]] : index to i32
64- // CHECK: linalg.yield [[VAR_14_1_]] : i32
67+ // CHECK: [[VAL_25:%.+]] = arith.addi [[VAR_14_1_]], [[CST_1024_]] : i32
68+ // CHECK: linalg.yield [[VAL_25]] : i32
6569// CHECK: } -> tensor<128xi32>
6670// CHECK: [[VAR_expanded_0_:%.+]] = tensor.expand_shape [[VAR_6_]] {{.}}[0, 1]{{.}} output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32>
6771// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map3, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_0_]] : tensor<1x128xi32>) outs([[VAR_0_]] : tensor<256x128xi32>) attrs = {broadcastDims = array<i64: 0>} {