diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir index 1a93d1cd9b788..b375fad2ce5d6 100644 --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -807,56 +807,41 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> { +func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index - %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32> - - %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) { - ^bb0(%out: i32): - %8 = linalg.index 0 : index - %idx_0 = linalg.index 0 : index - %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32> - linalg.yield %extracted : i32 + %src = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32> + + %res = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + outs(%init : tensor<1x1x4xi32>) { + + ^bb0(%out: i32): + %idx = linalg.index 0 : index + %extracted = tensor.extract %src[%idx, %c0] : tensor<15x1xi32> + linalg.yield %extracted : i32 } -> tensor<1x1x4xi32> - return %out:tensor<1x1x4xi32> + return %res : tensor<1x1x4xi32> } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)> -// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> { -// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index -// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index 
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32> -// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32> -// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex> -// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex> -// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex> -// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex> -// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex> -// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex> -// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1> -// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32> -// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex> -// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1> -// CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32> -// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], 
%[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32> -// CHECK: return %[[VAL_25]] : tensor<1x1x4xi32> +// CHECK-LABEL: func.func @vectorize_scalar_read_with_broadcast_from_column_tensor( +// CHECK-SAME: %[[INIT:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> { +// CHECK: %[[PAD:.*]] = arith.constant 0 : i32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex> +// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32> +// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32> +// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op transform.yield } }